diff --git a/README.md b/README.md
index c307ad5..c1c5680 100755
--- a/README.md
+++ b/README.md
@@ -32,38 +32,35 @@ git clone https://github.com/pliang279/LG-FedAvg.git
 We run FedAvg and LG-FedAvg experiments on MNIST ([link](http://yann.lecun.com/exdb/mnist/)) and CIFAR10 ([link](https://www.cs.toronto.edu/~kriz/cifar.html)). See our paper for a description how we process and partition the data for federated learning experiments.
 
+## FedAvg
-## LG-FedAvg
-
-Results can be reproduced using the following:
+Results can be reproduced running the following:
 
 #### MNIST
-> python main_lg.py --dataset mnist --model mlp --num_classes 10 --epochs 1500 --lr 0.05 --num_users 100 --frac 0.1 --local_ep 1 --local_bs 10 --num_layers_keep 3
+> python main_fed.py --dataset mnist --model mlp --num_classes 10 --epochs 1000 --lr 0.05 --num_users 100 --shard_per_user 2 --frac 0.1 --local_ep 1 --local_bs 10 --results_save run1
 
 #### CIFAR10
-> python main_lg.py --dataset cifar10 --model cnn --num_classes 10 --epochs 2000 --lr 0.1 --num_users 100 --frac 0.1 --local_ep 1 --local_bs 50 --num_layers_keep 2
-
+> python main_fed.py --dataset cifar10 --model cnn --num_classes 10 --epochs 2000 --lr 0.1 --num_users 100 --shard_per_user 2 --frac 0.1 --local_ep 1 --local_bs 50 --results_save run1
 
-## FedAvg
+## LG-FedAvg
 
-Results can be reproduced using the following:
+Results can be reproduced by first running the above commands for FedAvg and then running the following:
 
 #### MNIST
-> python main_fed.py --dataset mnist --model mlp --num_classes 10 --epochs 1500 --lr 0.05 --num_users 100 --frac 0.1 --local_ep 1 --local_bs 10
+> python main_lg.py --dataset mnist --model mlp --num_classes 10 --epochs 200 --lr 0.05 --num_users 100 --shard_per_user 2 --frac 0.1 --local_ep 1 --local_bs 10 --num_layers_keep 3 --results_save run1 --load_fed best_400.pt
 
 #### CIFAR10
-> python main_fed.py --dataset cifar10 --model cnn --num_classes 10 --epochs 2000 --lr 0.1 --num_users 100 --frac 0.1 --local_ep 1 --local_bs 50
-
+> python main_lg.py --dataset cifar10 --model cnn --num_classes 10 --epochs 200 --lr 0.1 --num_users 100 --shard_per_user 2 --frac 0.1 --local_ep 1 --local_bs 50 --num_layers_keep 2 --results_save run1 --load_fed best_1200.pt
 
 ## MTL
 
-Results can be reproduced using the following:
+Results can be reproduced running the following:
 
 #### MNIST
-> python main_mtl.py --dataset mnist --model mlp --num_classes 10 --epochs 500 --lr 0.05 --num_users 100 --frac 1 --local_ep 1 --local_bs 10 --num_layers_keep 5
+> python main_mtl.py --dataset mnist --model mlp --num_classes 10 --epochs 1000 --lr 0.05 --num_users 100 --shard_per_user 2 --frac 0.1 --local_ep 1 --local_bs 10 --num_layers_keep 5 --results_save run1
 
 #### CIFAR10
-> python main_mtl.py --dataset cifar10 --model cnn --num_classes 10 --epochs 1500 --lr 0.1 --num_users 100 --frac 1 --local_ep 1 --local_bs 50 --num_layers_keep 5
+> python main_mtl.py --dataset cifar10 --model cnn --num_classes 10 --epochs 2000 --lr 0.1 --num_users 100 --shard_per_user 2 --frac 0.1 --local_ep 1 --local_bs 50 --num_layers_keep 5 --results_save run1
 
 If you use this code, please cite our paper:
diff --git a/main_lg.py b/main_lg.py
index ae704fc..26c3d79 100644
--- a/main_lg.py
+++ b/main_lg.py
@@ -13,7 +13,7 @@
 from utils.options import args_parser
 from utils.train_utils import get_data, get_model
 from models.Update import LocalUpdate
-from models.test import test_img_local_all, test_img_avg_all, test_img_ensemble_all
+from models.test import test_img_local_all, test_img_avg_all, test_img_ensemble_all, test_img_local
 
 import pdb
 
@@ -26,7 +26,7 @@
         args.dataset, args.model, args.iid, args.num_users, args.frac, args.local_ep, args.shard_per_user, args.results_save)
     assert(len(args.load_fed) > 0)
-    base_save_dir = os.path.join(base_dir, 'lg2/{}'.format(args.load_fed))
+    base_save_dir = os.path.join(base_dir, 'lg/{}'.format(args.load_fed))
     if not os.path.exists(base_save_dir):
        os.makedirs(base_save_dir, exist_ok=True)
 
@@ -61,22 +61,25 @@
     for user in range(args.num_users):
         net_local_list.append(copy.deepcopy(net_glob))
 
-    acc_test_local, loss_test_local = test_img_local_all(net_local_list, args, dataset_test, dict_users_test)
     acc_test_avg, loss_test_avg = test_img_avg_all(net_glob, net_local_list, args, dataset_test)
+    acc_test_local_list, _ = test_img_local_all(net_local_list, args, dataset_test, dict_users_test, return_all=True)
+    acc_test_local = acc_test_local_list.mean()
 
     # training
     results_save_path = os.path.join(base_save_dir, 'results.csv')
+    results_columns = ['epoch', 'acc_test_local', 'acc_test_avg', 'best_acc_local', 'acc_test_ens_avg', 'acc_test_ens_maj']
 
     loss_train = []
-    net_best = None
-    best_acc = np.ones(args.num_users) * acc_test_local
+    best_iter = -1
+    best_acc_local = -1
+    best_acc_list = acc_test_local_list
     best_net_list = copy.deepcopy(net_local_list)
+    final_net_list = copy.deepcopy(net_local_list)
 
-    lr = args.lr
     results = []
-    results.append(np.array([-1, acc_test_local, acc_test_avg, best_acc.mean(), None, None]))
-    print('Round {:3d}, Avg Loss {:.3f}, Loss (local): {:.3f}, Acc (local): {:.2f}, Loss (avg): {:.3}, Acc (avg): {:.2f}'.format(
-        -1, -1, loss_test_local, acc_test_local, loss_test_avg, acc_test_avg))
+    results.append(np.array([-1, acc_test_local, acc_test_avg, acc_test_local, None, None]))
+    print('Round {:3d}, Acc (local): {:.2f}, Acc (avg): {:.2f}, Acc (local-best): {:.2f}'.format(
+        -1, acc_test_local, acc_test_avg, acc_test_local))
 
     for iter in range(args.epochs):
         w_glob = {}
@@ -89,64 +92,75 @@
             local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users_train[idx])
             net_local = net_local_list[idx]
 
-            w_local, loss = local.train(net=net_local.to(args.device), lr=lr)
+            w_local, loss = local.train(net=net_local.to(args.device), lr=args.lr)
             loss_locals.append(copy.deepcopy(loss))
 
-            modules_glob = set([x.split('.')[0] for x in w_keys_epoch])
-            modules_all = net_local.__dict__['_modules']
-
             # sum up weights
             if len(w_glob) == 0:
-                w_glob = copy.deepcopy(net_glob.state_dict())
+                w_glob = copy.deepcopy(w_local)
             else:
                 for k in w_keys_epoch:
                     w_glob[k] += w_local[k]
+        loss_avg = sum(loss_locals) / len(loss_locals)
+        loss_train.append(loss_avg)
 
         # get weighted average for global weights
         for k in w_keys_epoch:
             w_glob[k] = torch.div(w_glob[k], m)
 
-        # copy weight to the global model (not really necessary)
-        net_glob.load_state_dict(w_glob)
-
         # copy weights to each local model
         for idx in range(args.num_users):
             net_local = net_local_list[idx]
             w_local = net_local.state_dict()
             for k in w_keys_epoch:
                 w_local[k] = w_glob[k]
-
             net_local.load_state_dict(w_local)
 
-        loss_avg = sum(loss_locals) / len(loss_locals)
-        loss_train.append(loss_avg)
-
-        # eval
-        acc_test_local, loss_test_local = test_img_local_all(net_local_list, args, dataset_test, dict_users_test, return_all=True)
-
+        # find best local models from after round
+        acc_test_local_list, _ = test_img_local_all(net_local_list, args, dataset_test, dict_users_test, return_all=True)
         for user in range(args.num_users):
-            if acc_test_local[user] > best_acc[user]:
-                best_acc[user] = acc_test_local[user]
+            if acc_test_local_list[user] >= best_acc_list[user]:
+                best_acc_list[user] = acc_test_local_list[user]
                 best_net_list[user] = copy.deepcopy(net_local_list[user])
 
-                model_save_path = os.path.join(base_save_dir, 'model_user{}.pt'.format(user))
-                torch.save(best_net_list[user].state_dict(), model_save_path)
+        # average best models for local test
+        acc_test_avg, loss_test_avg, net_glob = test_img_avg_all(net_glob, best_net_list, args, dataset_test, return_net=True)
+
+        # average global layers of the best local models
+        best_net_list_avg = copy.deepcopy(best_net_list)
+        w_glob = net_glob.state_dict()
+        for net_local in best_net_list_avg:
+            w_local = net_local.state_dict()
+            for k in w_keys_epoch:
+                w_local[k] = w_glob[k]
+            net_local.load_state_dict(w_local)
+
+        acc_test_local, _ = test_img_local_all(best_net_list_avg, args, dataset_test, dict_users_test)
+        if acc_test_local > best_acc_local:
+            best_acc_local = acc_test_local
+            best_iter = iter
+            final_net_list = copy.deepcopy(best_net_list_avg)
 
-        acc_test_local, loss_test_local = test_img_local_all(best_net_list, args, dataset_test, dict_users_test)
-        acc_test_avg, loss_test_avg = test_img_avg_all(net_glob, best_net_list, args, dataset_test)
-        print('Round {:3d}, Avg Loss {:.3f}, Loss (local): {:.3f}, Acc (local): {:.2f}, Loss (avg): {:.3}, Acc (avg): {:.2f}'.format(
-            iter, loss_avg, loss_test_local, acc_test_local, loss_test_avg, acc_test_avg))
+        print('Round {:3d}, Acc (local): {:.2f}, Acc (avg): {:.2f}, Acc (local-best): {:.2f}'.format(
+            iter, acc_test_local, acc_test_avg, best_acc_local))
 
-        results.append(np.array([iter, acc_test_local, acc_test_avg, best_acc.mean(), None, None]))
+        results.append(np.array([iter, acc_test_local, acc_test_avg, best_acc_local, None, None]))
         final_results = np.array(results)
-        final_results = pd.DataFrame(final_results, columns=['epoch', 'acc_test_local', 'acc_test_avg', 'best_acc_local', 'acc_test_ens_avg', 'acc_test_ens_maj'])
+        final_results = pd.DataFrame(final_results, columns=results_columns)
         final_results.to_csv(results_save_path, index=False)
 
-    acc_test_ens_avg, loss_test, acc_test_ens_maj = test_img_ensemble_all(best_net_list, args, dataset_test)
-    print('Best model, acc (local): {}, acc (ens,avg): {}, acc (ens,maj): {}'.format(best_acc.mean(), acc_test_ens_avg, acc_test_ens_maj))
+    acc_test_local, loss_test_local = test_img_local_all(final_net_list, args, dataset_test, dict_users_test)
+    acc_test_avg, loss_test_avg = test_img_avg_all(net_glob, final_net_list, args, dataset_test)
+    acc_test_ens_avg, loss_test, acc_test_ens_maj = test_img_ensemble_all(final_net_list, args, dataset_test)
+    print('Best model, acc (local): {}, acc (avg): {}, acc (ens,avg): {}, acc (ens,maj): {}'.format(
+        acc_test_local, acc_test_avg, acc_test_ens_avg, acc_test_ens_maj))
 
-    results.append(np.array(['Final', None, None, best_acc.mean(), acc_test_ens_avg, acc_test_ens_maj]))
+    results.append(np.array([best_iter, None, acc_test_avg, acc_test_local, acc_test_ens_avg, acc_test_ens_maj]))
     final_results = np.array(results)
-    final_results = pd.DataFrame(final_results,
-                                 columns=['epoch', 'acc_test_local', 'acc_test_avg', 'best_acc_local', 'acc_test_ens_avg', 'acc_test_ens_maj'])
-    final_results.to_csv(results_save_path, index=False)
\ No newline at end of file
+    final_results = pd.DataFrame(final_results, columns=results_columns)
+    final_results.to_csv(results_save_path, index=False)
+
+    # save models
+    for user, net_local in enumerate(final_net_list):
+        model_save_path = os.path.join(base_save_dir, 'model_user{}.pt'.format(user))
+        torch.save(net_local.state_dict(), model_save_path)
\ No newline at end of file
diff --git a/main_lg_backup.py b/main_lg_backup.py
index c1a0583..799b3c8 100644
--- a/main_lg_backup.py
+++ b/main_lg_backup.py
@@ -13,7 +13,7 @@
 from utils.options import args_parser
 from utils.train_utils import get_data, get_model
 from models.Update import LocalUpdate
-from models.test import test_img_local_all, test_img_avg_all, test_img_ensemble_all
+from models.test import test_img_local_all, test_img_avg_all, test_img_ensemble_all, test_img_local
 
 import pdb
 
@@ -26,7 +26,7 @@
         args.dataset, args.model, args.iid, args.num_users, args.frac, args.local_ep, args.shard_per_user, args.results_save)
     assert(len(args.load_fed) > 0)
-    base_save_dir = os.path.join(base_dir, 'lg/{}'.format(args.load_fed))
+    base_save_dir = os.path.join(base_dir, 'lg3/{}'.format(args.load_fed))
     if not os.path.exists(base_save_dir):
        os.makedirs(base_save_dir, exist_ok=True)
 
@@ -55,102 +55,111 @@
     percentage_param = 100 * float(num_param_glob) / num_param_local
     print('# Params: {} (local), {} (global); Percentage {:.2f} ({}/{})'.format(
         num_param_local, num_param_glob, percentage_param, num_param_glob, num_param_local))
+
     # generate list of local models for each user
     net_local_list = []
-    for user_ix in range(args.num_users):
+    for user in range(args.num_users):
         net_local_list.append(copy.deepcopy(net_glob))
 
+    acc_test_local, loss_test_local = test_img_local_all(net_local_list, args, dataset_test, dict_users_test)
+    acc_test_avg, loss_test_avg = test_img_avg_all(net_glob, net_local_list, args, dataset_test)
+
     # training
     results_save_path = os.path.join(base_save_dir, 'results.csv')
+    results_columns = ['epoch', 'acc_test_local', 'acc_test_avg', 'best_acc_local', 'acc_test_ens_avg', 'acc_test_ens_maj']
 
     loss_train = []
-    net_best = None
-    best_loss = None
-    best_acc = None
-    best_epoch = None
+    best_iter = -1
+    best_acc_local = -1
+    best_acc_list = np.ones(args.num_users) * acc_test_local
+    best_net_list = copy.deepcopy(net_local_list)
+    final_net_list = copy.deepcopy(net_local_list)
 
-    lr = args.lr
     results = []
+    results.append(np.array([-1, acc_test_local, acc_test_avg, acc_test_local, None, None]))
+    print('Round {:3d}, Acc (local): {:.2f}, Acc (avg): {:.2f}, Acc (local-best): {:.2f}'.format(
+        -1, acc_test_local, acc_test_avg, acc_test_local))
 
     for iter in range(args.epochs):
         w_glob = {}
         loss_locals = []
         m = max(int(args.frac * args.num_users), 1)
         idxs_users = np.random.choice(range(args.num_users), m, replace=False)
 
-        # w_keys_epoch = net_glob.state_dict().keys() if (iter + 1) % 25 == 0 else w_glob_keys
         w_keys_epoch = w_glob_keys
 
-        if args.verbose:
-            print("Round {}: lr: {:.6f}, {}".format(iter, lr, idxs_users))
-
         for idx in idxs_users:
             local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users_train[idx])
             net_local = net_local_list[idx]
 
-            w_local, loss = local.train(net=net_local.to(args.device), lr=lr)
+            w_local, loss = local.train(net=net_local.to(args.device), lr=args.lr)
             loss_locals.append(copy.deepcopy(loss))
 
-            modules_glob = set([x.split('.')[0] for x in w_keys_epoch])
-            modules_all = net_local.__dict__['_modules']
-
             # sum up weights
             if len(w_glob) == 0:
-                w_glob = copy.deepcopy(net_glob.state_dict())
+                w_glob = copy.deepcopy(w_local)
             else:
                 for k in w_keys_epoch:
                     w_glob[k] += w_local[k]
-
-        if (iter+1) % int(args.num_users * args.frac):
-            lr *= args.lr_decay
+        loss_avg = sum(loss_locals) / len(loss_locals)
+        loss_train.append(loss_avg)
 
         # get weighted average for global weights
         for k in w_keys_epoch:
             w_glob[k] = torch.div(w_glob[k], m)
 
-        # copy weight to the global model (not really necessary)
-        net_glob.load_state_dict(w_glob)
-
         # copy weights to each local model
         for idx in range(args.num_users):
             net_local = net_local_list[idx]
             w_local = net_local.state_dict()
             for k in w_keys_epoch:
                 w_local[k] = w_glob[k]
-
             net_local.load_state_dict(w_local)
 
-        loss_avg = sum(loss_locals) / len(loss_locals)
-        loss_train.append(loss_avg)
+        # find best local models from after round
+        acc_test_local_list, _ = test_img_local_all(net_local_list, args, dataset_test, dict_users_test, return_all=True)
+        for user in range(args.num_users):
+            if acc_test_local_list[user] > best_acc_list[user]:
+                best_acc_list[user] = acc_test_local_list[user]
+                best_net_list[user] = copy.deepcopy(net_local_list[user])
 
-        # eval
-        acc_test_local, loss_test_local = test_img_local_all(net_local_list, args, dataset_test, dict_users_test)
-        acc_test_avg, loss_test_avg = test_img_avg_all(net_glob, net_local_list, args, dataset_test)
-        print('Round {:3d}, Avg Loss {:.3f}, Loss (local): {:.3f}, Acc (local): {:.2f}, Loss (avg): {:.3}, Acc (avg): {:.2f}'.format(
-            iter, loss_avg, loss_test_local, acc_test_local, loss_test_avg, acc_test_avg))
+        # average best models for local test
+        acc_test_avg, loss_test_avg, net_glob = test_img_avg_all(net_glob, best_net_list, args, dataset_test, return_net=True)
 
-        if best_acc is None or acc_test_local > best_acc:
-            best_acc = acc_test_local
-            best_epoch = iter
+        # average global layers of the best local models
+        best_net_list_avg = copy.deepcopy(best_net_list)
+        w_glob = net_glob.state_dict()
+        for net_local in best_net_list_avg:
+            w_local = net_local.state_dict()
+            for k in w_keys_epoch:
+                w_local[k] = w_glob[k]
+            net_local.load_state_dict(w_local)
+
+        acc_test_local, loss_test_local = test_img_local_all(best_net_list_avg, args, dataset_test, dict_users_test)
+        if acc_test_local > best_acc_local:
+            best_acc_local = acc_test_local
+            best_iter = iter
+            final_net_list = copy.deepcopy(best_net_list_avg)
 
-            for user in range(args.num_users):
-                model_save_path = os.path.join(base_save_dir, 'model_user{}.pt'.format(user))
-                torch.save(net_local_list[user].state_dict(), model_save_path)
+        print('Round {:3d}, Acc (local): {:.2f}, Acc (avg): {:.2f}, Acc (local-best): {:.2f}'.format(
+            iter, acc_test_local, acc_test_avg, best_acc_local))
 
-        results.append(np.array([iter, acc_test_local, acc_test_avg, best_acc, None, None]))
+        results.append(np.array([iter, acc_test_local, acc_test_avg, best_acc_local, None, None]))
         final_results = np.array(results)
-        final_results = pd.DataFrame(final_results, columns=['epoch', 'acc_test_local', 'acc_test_avg', 'best_acc_local', 'acc_test_ens_avg', 'acc_test_ens_maj'])
+        final_results = pd.DataFrame(final_results, columns=results_columns)
         final_results.to_csv(results_save_path, index=False)
 
-    for user in range(args.num_users):
-        model_save_path = os.path.join(base_save_dir, 'model_user{}.pt'.format(user))
-
-        net_local = net_local_list[user]
-        net_local.load_state_dict(torch.load(model_save_path))
-
-    acc_test_ens_avg, loss_test, acc_test_ens_maj = test_img_ensemble_all(net_local_list, args, dataset_test)
-
-    print('Best model, iter: {}, acc (local): {}, acc (ens,avg): {}, acc (ens,maj): {}'.format(best_epoch, best_acc, acc_test_ens_avg, acc_test_ens_maj))
+    acc_test_local, loss_test_local = test_img_local_all(final_net_list, args, dataset_test, dict_users_test)
+    acc_test_avg, loss_test_avg = test_img_avg_all(net_glob, final_net_list, args, dataset_test)
+    acc_test_ens_avg, loss_test, acc_test_ens_maj = test_img_ensemble_all(final_net_list, args, dataset_test)
+    print('Best model, acc (local): {}, acc (avg): {}, acc (ens,avg): {}, acc (ens,maj): {}'.format(
+        acc_test_local, acc_test_avg, acc_test_ens_avg, acc_test_ens_maj))
 
-    results.append(np.array(['Final', None, None, best_acc, acc_test_ens_avg, acc_test_ens_maj]))
+    results.append(np.array([best_iter, None, acc_test_avg, acc_test_local, acc_test_ens_avg, acc_test_ens_maj]))
     final_results = np.array(results)
-    final_results = pd.DataFrame(final_results,
-                                 columns=['epoch', 'acc_test_local', 'acc_test_avg', 'best_acc_local', 'acc_test_ens_avg', 'acc_test_ens_maj'])
+    final_results = pd.DataFrame(final_results, columns=results_columns)
     final_results.to_csv(results_save_path, index=False)
+
+    # save models
+    for user, net_local in enumerate(final_net_list):
+        model_save_path = os.path.join(base_save_dir, 'model_user{}.pt'.format(user))
+        torch.save(net_local.state_dict(), model_save_path)
\ No newline at end of file
diff --git a/main_mtl.py b/main_mtl.py
index c3ac4dc..a0e5c87 100644
--- a/main_mtl.py
+++ b/main_mtl.py
@@ -100,8 +100,7 @@
         idxs_users = np.random.choice(range(args.num_users), m, replace=False)
 
         W = torch.zeros((d, m)).cuda()
-        # for idx, user in enumerate(idxs_users):
-        for idx, user in enumerate(range(10)):
+        for idx, user in enumerate(idxs_users):
             W_local = [net_local_list[user].state_dict()[key].flatten() for key in w_glob_keys]
             W_local = torch.cat(W_local)
             W[:, idx] = W_local
diff --git a/models/test.py b/models/test.py
index 2df3fcc..69f216d 100644
--- a/models/test.py
+++ b/models/test.py
@@ -66,7 +66,7 @@ def test_img_local(net_g, dataset, args, user_idx=-1, idxs=None):
    test_loss = 0
    correct = 0
    # data_loader = DataLoader(dataset, batch_size=args.bs)
-    data_loader = DataLoader(DatasetSplit(dataset, idxs), batch_size=args.bs, shuffle=True)
+    data_loader = DataLoader(DatasetSplit(dataset, idxs), batch_size=args.bs, shuffle=False)
    l = len(data_loader)
 
    for idx, (data, target) in enumerate(data_loader):
diff --git a/scripts/main_lg_cifar10.sh b/scripts/main_lg_cifar10.sh
index a09a0f3..995b1c5 100755
--- a/scripts/main_lg_cifar10.sh
+++ b/scripts/main_lg_cifar10.sh
@@ -11,7 +11,7 @@ done
 
 for FED_MODEL in 800 1000 1200 1400 1600 1800; do
     for RUN in 1 2 3 4 5; do
-        python3 main_lg.py --dataset cifar10 --model cnn --num_classes 10 --epochs 200 --lr 0.1 \
+        python3 main_lg.py --dataset cifar10 --model cnn --num_classes 10 --epochs 500 --lr 0.1 \
             --num_users ${NUM_USERS} --shard_per_user ${SHARDS} --frac 0.1 --local_ep 1 --local_bs 50 --num_layers_keep 2 \
             --results_save run${RUN} --load_fed best_${FED_MODEL}.pt
     done
diff --git a/scripts/main_lg_mnist.sh b/scripts/main_lg_mnist.sh
index 65344b8..0dc2ce6 100755
--- a/scripts/main_lg_mnist.sh
+++ b/scripts/main_lg_mnist.sh
@@ -11,7 +11,7 @@ done
 
 for FED_MODEL in 400 500 600 700 800; do
     for RUN in 1 2 3 4 5; do
-        python3 main_lg.py --dataset mnist --model mlp --num_classes 10 --epochs 200 --lr 0.05 \
+        python3 main_lg.py --dataset mnist --model mlp --num_classes 10 --epochs 500 --lr 0.05 \
             --num_users ${NUM_USERS} --shard_per_user ${SHARDS} --frac 0.1 --local_ep 1 --local_bs 10 --num_layers_keep 3 \
             --results_save run${RUN} --load_fed best_${FED_MODEL}.pt
     done
diff --git a/test_local_numusers.py b/test_local_numusers.py
index 5647914..9e2b81e 100644
--- a/test_local_numusers.py
+++ b/test_local_numusers.py
@@ -28,13 +28,13 @@
 args = args_parser()
 args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu')
 args.num_classes = 10
-args.epochs = 100
+args.epochs = 200
 
 if args.dataset == "mnist":
     args.local_bs = 10
 else:
     args.local_bs = 50
 
-early_stopping = 20
+early_stopping = 100
 args.shard_per_user = 2
 for num_users in [5, 10, 20, 50, 100, 200, 500, 1000]: