Skip to content

Commit

Permalink
merged
Browse files Browse the repository at this point in the history
  • Loading branch information
terranceliu committed Apr 30, 2020
1 parent 39def97 commit 8ee6a68
Show file tree
Hide file tree
Showing 8 changed files with 129 additions and 110 deletions.
25 changes: 11 additions & 14 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,38 +32,35 @@ git clone https://github.com/pliang279/LG-FedAvg.git

We run FedAvg and LG-FedAvg experiments on MNIST ([link](http://yann.lecun.com/exdb/mnist/)) and CIFAR10 ([link](https://www.cs.toronto.edu/~kriz/cifar.html)). See our paper for a description of how we process and partition the data for federated learning experiments.

## FedAvg

## LG-FedAvg

Results can be reproduced using the following:
Results can be reproduced by running the following:

#### MNIST
> python main_lg.py --dataset mnist --model mlp --num_classes 10 --epochs 1500 --lr 0.05 --num_users 100 --frac 0.1 --local_ep 1 --local_bs 10 --num_layers_keep 3
> python main_fed.py --dataset mnist --model mlp --num_classes 10 --epochs 1000 --lr 0.05 --num_users 100 --shard_per_user 2 --frac 0.1 --local_ep 1 --local_bs 10 --results_save run1
#### CIFAR10
> python main_lg.py --dataset cifar10 --model cnn --num_classes 10 --epochs 2000 --lr 0.1 --num_users 100 --frac 0.1 --local_ep 1 --local_bs 50 --num_layers_keep 2
> python main_fed.py --dataset cifar10 --model cnn --num_classes 10 --epochs 2000 --lr 0.1 --num_users 100 --shard_per_user 2 --frac 0.1 --local_ep 1 --local_bs 50 --results_save run1
## FedAvg
## LG-FedAvg

Results can be reproduced using the following:
Results can be reproduced by first running the above commands for FedAvg and then running the following:

#### MNIST
> python main_fed.py --dataset mnist --model mlp --num_classes 10 --epochs 1500 --lr 0.05 --num_users 100 --frac 0.1 --local_ep 1 --local_bs 10
> python main_lg.py --dataset mnist --model mlp --num_classes 10 --epochs 200 --lr 0.05 --num_users 100 --shard_per_user 2 --frac 0.1 --local_ep 1 --local_bs 10 --num_layers_keep 3 --results_save run1 --load_fed best_400.pt
#### CIFAR10
> python main_fed.py --dataset cifar10 --model cnn --num_classes 10 --epochs 2000 --lr 0.1 --num_users 100 --frac 0.1 --local_ep 1 --local_bs 50
> python main_lg.py --dataset cifar10 --model cnn --num_classes 10 --epochs 200 --lr 0.1 --num_users 100 --shard_per_user 2 --frac 0.1 --local_ep 1 --local_bs 50 --num_layers_keep 2 --results_save run1 --load_fed best_1200.pt
## MTL

Results can be reproduced using the following:
Results can be reproduced by running the following:

#### MNIST
> python main_mtl.py --dataset mnist --model mlp --num_classes 10 --epochs 500 --lr 0.05 --num_users 100 --frac 1 --local_ep 1 --local_bs 10 --num_layers_keep 5
> python main_mtl.py --dataset mnist --model mlp --num_classes 10 --epochs 1000 --lr 0.05 --num_users 100 --shard_per_user 2 --frac 0.1 --local_ep 1 --local_bs 10 --num_layers_keep 5 --results_save run1
#### CIFAR10
> python main_mtl.py --dataset cifar10 --model cnn --num_classes 10 --epochs 1500 --lr 0.1 --num_users 100 --frac 1 --local_ep 1 --local_bs 50 --num_layers_keep 5
> python main_mtl.py --dataset cifar10 --model cnn --num_classes 10 --epochs 2000 --lr 0.1 --num_users 100 --shard_per_user 2 --frac 0.1 --local_ep 1 --local_bs 50 --num_layers_keep 5 --results_save run1

If you use this code, please cite our paper:
Expand Down
94 changes: 54 additions & 40 deletions main_lg.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from utils.options import args_parser
from utils.train_utils import get_data, get_model
from models.Update import LocalUpdate
from models.test import test_img_local_all, test_img_avg_all, test_img_ensemble_all
from models.test import test_img_local_all, test_img_avg_all, test_img_ensemble_all, test_img_local

import pdb

Expand All @@ -26,7 +26,7 @@
args.dataset, args.model, args.iid, args.num_users, args.frac, args.local_ep, args.shard_per_user, args.results_save)

assert(len(args.load_fed) > 0)
base_save_dir = os.path.join(base_dir, 'lg2/{}'.format(args.load_fed))
base_save_dir = os.path.join(base_dir, 'lg/{}'.format(args.load_fed))
if not os.path.exists(base_save_dir):
os.makedirs(base_save_dir, exist_ok=True)

Expand Down Expand Up @@ -61,22 +61,25 @@
for user in range(args.num_users):
net_local_list.append(copy.deepcopy(net_glob))

acc_test_local, loss_test_local = test_img_local_all(net_local_list, args, dataset_test, dict_users_test)
acc_test_avg, loss_test_avg = test_img_avg_all(net_glob, net_local_list, args, dataset_test)
acc_test_local_list, _ = test_img_local_all(net_local_list, args, dataset_test, dict_users_test, return_all=True)
acc_test_local = acc_test_local_list.mean()

# training
results_save_path = os.path.join(base_save_dir, 'results.csv')
results_columns = ['epoch', 'acc_test_local', 'acc_test_avg', 'best_acc_local', 'acc_test_ens_avg', 'acc_test_ens_maj']

loss_train = []
net_best = None
best_acc = np.ones(args.num_users) * acc_test_local
best_iter = -1
best_acc_local = -1
best_acc_list = acc_test_local_list
best_net_list = copy.deepcopy(net_local_list)
fina_net_list = copy.deepcopy(net_local_list)

lr = args.lr
results = []
results.append(np.array([-1, acc_test_local, acc_test_avg, best_acc.mean(), None, None]))
print('Round {:3d}, Avg Loss {:.3f}, Loss (local): {:.3f}, Acc (local): {:.2f}, Loss (avg): {:.3}, Acc (avg): {:.2f}'.format(
-1, -1, loss_test_local, acc_test_local, loss_test_avg, acc_test_avg))
results.append(np.array([-1, acc_test_local, acc_test_avg, acc_test_local, None, None]))
print('Round {:3d}, Acc (local): {:.2f}, Acc (avg): {:.2f}, Acc (local-best): {:.2f}'.format(
-1, acc_test_local, acc_test_avg, acc_test_local))

for iter in range(args.epochs):
w_glob = {}
Expand All @@ -89,64 +92,75 @@
local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users_train[idx])
net_local = net_local_list[idx]

w_local, loss = local.train(net=net_local.to(args.device), lr=lr)
w_local, loss = local.train(net=net_local.to(args.device), lr=args.lr)
loss_locals.append(copy.deepcopy(loss))

modules_glob = set([x.split('.')[0] for x in w_keys_epoch])
modules_all = net_local.__dict__['_modules']

# sum up weights
if len(w_glob) == 0:
w_glob = copy.deepcopy(net_glob.state_dict())
w_glob = copy.deepcopy(w_local)
else:
for k in w_keys_epoch:
w_glob[k] += w_local[k]
loss_avg = sum(loss_locals) / len(loss_locals)
loss_train.append(loss_avg)

# get weighted average for global weights
for k in w_keys_epoch:
w_glob[k] = torch.div(w_glob[k], m)

# copy weight to the global model (not really necessary)
net_glob.load_state_dict(w_glob)

# copy weights to each local model
for idx in range(args.num_users):
net_local = net_local_list[idx]
w_local = net_local.state_dict()
for k in w_keys_epoch:
w_local[k] = w_glob[k]

net_local.load_state_dict(w_local)

loss_avg = sum(loss_locals) / len(loss_locals)
loss_train.append(loss_avg)

# eval
acc_test_local, loss_test_local = test_img_local_all(net_local_list, args, dataset_test, dict_users_test, return_all=True)

# find best local models from after round
acc_test_local_list, _ = test_img_local_all(net_local_list, args, dataset_test, dict_users_test, return_all=True)
for user in range(args.num_users):
if acc_test_local[user] > best_acc[user]:
best_acc[user] = acc_test_local[user]
if acc_test_local_list[user] >= best_acc_list[user]:
best_acc_list[user] = acc_test_local_list[user]
best_net_list[user] = copy.deepcopy(net_local_list[user])

model_save_path = os.path.join(base_save_dir, 'model_user{}.pt'.format(user))
torch.save(best_net_list[user].state_dict(), model_save_path)
# average best models for local test
acc_test_avg, loss_test_avg, net_glob = test_img_avg_all(net_glob, best_net_list, args, dataset_test, return_net=True)

# average global layers of the best local models
best_net_list_avg = copy.deepcopy(best_net_list)
w_glob = net_glob.state_dict()
for net_local in best_net_list_avg:
w_local = net_local.state_dict()
for k in w_keys_epoch:
w_local[k] = w_glob[k]
net_local.load_state_dict(w_local)

acc_test_local, _ = test_img_local_all(best_net_list_avg, args, dataset_test, dict_users_test)
if acc_test_local > best_acc_local:
best_acc_local = acc_test_local
best_iter = iter
final_net_list = copy.deepcopy(best_net_list_avg)

acc_test_local, loss_test_local = test_img_local_all(best_net_list, args, dataset_test, dict_users_test)
acc_test_avg, loss_test_avg = test_img_avg_all(net_glob, best_net_list, args, dataset_test)
print('Round {:3d}, Avg Loss {:.3f}, Loss (local): {:.3f}, Acc (local): {:.2f}, Loss (avg): {:.3}, Acc (avg): {:.2f}'.format(
iter, loss_avg, loss_test_local, acc_test_local, loss_test_avg, acc_test_avg))
print('Round {:3d}, Acc (local): {:.2f}, Acc (avg): {:.2f}, Acc (local-best): {:.2f}'.format(
iter, acc_test_local, acc_test_avg, best_acc_local))

results.append(np.array([iter, acc_test_local, acc_test_avg, best_acc.mean(), None, None]))
results.append(np.array([iter, acc_test_local, acc_test_avg, best_acc_local, None, None]))
final_results = np.array(results)
final_results = pd.DataFrame(final_results, columns=['epoch', 'acc_test_local', 'acc_test_avg', 'best_acc_local', 'acc_test_ens_avg', 'acc_test_ens_maj'])
final_results = pd.DataFrame(final_results, columns=results_columns)
final_results.to_csv(results_save_path, index=False)

acc_test_ens_avg, loss_test, acc_test_ens_maj = test_img_ensemble_all(best_net_list, args, dataset_test)
print('Best model, acc (local): {}, acc (ens,avg): {}, acc (ens,maj): {}'.format(best_acc.mean(), acc_test_ens_avg, acc_test_ens_maj))
acc_test_local, loss_test_local = test_img_local_all(final_net_list, args, dataset_test, dict_users_test)
acc_test_avg, loss_test_avg = test_img_avg_all(net_glob, final_net_list, args, dataset_test)
acc_test_ens_avg, loss_test, acc_test_ens_maj = test_img_ensemble_all(final_net_list, args, dataset_test)
print('Best model, acc (local): {}, acc (avg): {}, acc (ens,avg): {}, acc (ens,maj): {}'.format(
acc_test_local, acc_test_avg, acc_test_ens_avg, acc_test_ens_maj))

results.append(np.array(['Final', None, None, best_acc.mean(), acc_test_ens_avg, acc_test_ens_maj]))
results.append(np.array([best_iter, None, acc_test_avg, acc_test_local, acc_test_ens_avg, acc_test_ens_maj]))
final_results = np.array(results)
final_results = pd.DataFrame(final_results,
columns=['epoch', 'acc_test_local', 'acc_test_avg', 'best_acc_local', 'acc_test_ens_avg', 'acc_test_ens_maj'])
final_results.to_csv(results_save_path, index=False)
final_results = pd.DataFrame(final_results, columns=results_columns)
final_results.to_csv(results_save_path, index=False)

# save models
for user, net_local in enumerate(final_net_list):
model_save_path = os.path.join(base_save_dir, 'model_user{}.pt'.format(user))
torch.save(net_local.state_dict(), model_save_path)
Loading

0 comments on commit 8ee6a68

Please sign in to comment.