diff --git a/.gitignore b/.gitignore index 0d20b64..3ac6e56 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,140 @@ -*.pyc +.idea/ + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ diff --git a/get_results.ipynb b/get_results.ipynb new file mode 100644 index 0000000..d511a88 --- /dev/null +++ b/get_results.ipynb @@ -0,0 +1,385 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import re\n", + "import numpy as np\n", + "import pandas as pd" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "iid = True\n", + "num_users = 100\n", + "frac = 0.1\n", + "local_ep = 1\n", + "\n", + "dataset = 'mnist'\n", + "# dataset = 'cifar10'\n", + "# dataset = 'cifar100'\n", + "\n", + "shard_per_user = 10\n", + "\n", + "if dataset == 'mnist':\n", + " model = 'mlp'\n", + " rd_lg = 100\n", + " rd_fed = 800 + int(rd_lg*0.15)\n", + "elif dataset == 'cifar10':\n", + " model = 'cnn'\n", + " rd_lg = 100\n", + " rd_fed = 1800 + int(rd_lg*0.04)\n", + "elif dataset == 'cifar100':\n", + " model = 'cnn'\n", + " rd_fed = 1800\n", + " rd_lg = 100" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "base_dir = './save/{}/{}_iid{}_num{}_C{}_le{}/shard{}/'.format(\n", + " dataset, model, iid, num_users, frac, local_ep, shard_per_user)\n", + "runs = os.listdir(base_dir)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "acc_fed = np.zeros(len(runs))\n", + "acc_local_localtest = np.zeros(len(runs))\n", + "acc_local_newtest_avg = np.zeros(len(runs))\n", + "acc_local_newtest_ens = np.zeros(len(runs))\n", + "lg_metrics = {}\n", + "\n", + "for idx, run in enumerate(runs):\n", + " # FedAvg\n", + " base_dir_fed = os.path.join(base_dir, \"{}/fed\".format(run))\n", + " results_path_fed = os.path.join(base_dir_fed, \"results.csv\")\n", + " df_fed = pd.read_csv(results_path_fed)\n", + "\n", + " acc_fed[idx] = df_fed.loc[rd_fed - 1]['best_acc']\n", + " \n", + " # LocalOnly\n", + " base_dir_local = os.path.join(base_dir, \"{}/local\".format(run))\n", + " results_path_local = os.path.join(base_dir_local, \"results.csv\")\n", + " df_local = pd.read_csv(results_path_local)\n", + "\n", + " acc_local_localtest[idx] = df_local.loc[0]['acc_test_local']\n", + " acc_local_newtest_avg[idx] = df_local.loc[0]['acc_test_avg']\n", + " if 'acc_test_ens' in df_local.columns:\n", + " acc_local_newtest_ens[idx] = df_local.loc[0]['acc_test_ens']\n", + " else:\n", + " acc_local_newtest_ens[idx] = df_local.loc[0]['acc_test_ens_avg']\n", + " \n", + " # LGFed\n", + " base_dir_lg = os.path.join(base_dir, \"{}/lg/\".format(run))\n", + " lg_runs = os.listdir(base_dir_lg)\n", + " for lg_run in lg_runs:\n", + " results_path_lg = os.path.join(base_dir_lg, \"{}/results.csv\".format(lg_run))\n", + " df_lg = pd.read_csv(results_path_lg)\n", + " \n", + " load_fed = int(re.split('best_|.pt', lg_run)[1])\n", + " if load_fed not in lg_metrics.keys():\n", + " lg_metrics[load_fed] = {'acc_local': np.zeros(len(runs)),\n", + " 'acc_avg': np.zeros(len(runs)),\n", + " 'acc_ens': 
np.zeros(len(runs))}\n", + " \n", + " x = df_lg.loc[rd_lg]['best_acc_local']\n", + " lg_metrics[load_fed]['acc_local'][idx] = x\n", + " idx_acc_local = df_lg[df_lg['best_acc_local'] == x].index[0]\n", + " lg_metrics[load_fed]['acc_avg'][idx] = df_lg.loc[idx_acc_local]['acc_test_avg']\n", + " if 'acc_test_ens' in df_lg.columns:\n", + " lg_metrics[load_fed]['acc_ens'][idx] = df_lg['acc_test_ens'].values[-1]\n", + " else:\n", + " lg_metrics[load_fed]['acc_ens'][idx] = df_lg['acc_test_ens_avg'].values[-1]" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "columns = [\"Run\", \"Local Test\", \"New Test (avg)\", \"New Test (ens)\", \"FedAvg Rounds\", \"LG Rounds\"]\n", + "results = []" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "localonly:\t 88.03 +- 0.37\n", + "localonly_avg:\t 86.24 +- 0.87\n", + "localonly_ens:\t 91.15 +- 0.27\n" + ] + } + ], + "source": [ + "str_acc_local_localtest = \"{:.2f} +- {:.2f}\".format(acc_local_localtest.mean(), acc_local_localtest.std())\n", + "str_acc_local_newtest_avg = \"{:.2f} +- {:.2f}\".format(acc_local_newtest_avg.mean(), acc_local_newtest_avg.std())\n", + "str_acc_local_newtest_ens = \"{:.2f} +- {:.2f}\".format(acc_local_newtest_ens.mean(), acc_local_newtest_ens.std())\n", + "\n", + "print(\"localonly:\\t\", str_acc_local_localtest)\n", + "print(\"localonly_avg:\\t\", str_acc_local_newtest_avg)\n", + "print(\"localonly_ens:\\t\", str_acc_local_newtest_ens)\n", + "\n", + "results.append([\"LocalOnly\", str_acc_local_localtest, str_acc_local_newtest_avg, str_acc_local_newtest_ens, 0, 0])" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "300\n", + "acc_local:\t97.45 +- 0.12\n", + "acc_avg:\t97.50 +- 0.10\n", + "acc_ens:\t97.49 +- 0.09\n", + "400\n", + "acc_local:\t97.59 +- 0.08\n", + "acc_avg:\t97.61 +- 0.08\n", + "acc_ens:\t97.62 +- 0.08\n", + "500\n", + "acc_local:\t97.78 +- 0.13\n", + "acc_avg:\t97.82 +- 0.14\n", + "acc_ens:\t97.82 +- 0.13\n", + "600\n", + "acc_local:\t97.84 +- 0.10\n", + "acc_avg:\t97.86 +- 0.08\n", + "acc_ens:\t97.87 +- 0.10\n", + "700\n", + "acc_local:\t97.85 +- 0.09\n", + "acc_avg:\t97.88 +- 0.09\n", + "acc_ens:\t97.89 +- 0.09\n", + "800\n", + "acc_local:\t97.91 +- 0.10\n", + "acc_avg:\t97.93 +- 0.07\n", + "acc_ens:\t97.93 +- 0.07\n" + ] + } + ], + "source": [ + "for lg_run in sorted(lg_metrics.keys()):\n", + " x = [\"LG-FedAvg\"]\n", + " print(lg_run)\n", + " for array in ['acc_local', 'acc_avg', 'acc_ens']:\n", + " mean = lg_metrics[lg_run][array].mean()\n", + " std = lg_metrics[lg_run][array].std()\n", + " str_acc = \"{:.2f} +- {:.2f}\".format(mean, std)\n", + " print(\"{}:\\t{}\".format(array, str_acc))\n", + " \n", + " x.append(str_acc)\n", + " x.append(lg_run)\n", + " x.append(rd_lg)\n", + " results.append(x)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "fed:\t 97.93 +- 0.08\n" + ] + } + ], + "source": [ + "str_acc_fed = \"{:.2f} +- {:.2f}\".format(acc_fed.mean(), acc_fed.std())\n", + "print(\"fed:\\t\", str_acc_fed)\n", + "results.append([\"FedAvg\", str_acc_fed, str_acc_fed, str_acc_fed, rd_fed, 0])" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
Local TestNew Test (avg)New Test (ens)FedAvg RoundsLG Rounds
Run
LocalOnly88.03 +- 0.3786.24 +- 0.8791.15 +- 0.2700
LG-FedAvg97.45 +- 0.1297.50 +- 0.1097.49 +- 0.09300100
LG-FedAvg97.59 +- 0.0897.61 +- 0.0897.62 +- 0.08400100
LG-FedAvg97.78 +- 0.1397.82 +- 0.1497.82 +- 0.13500100
LG-FedAvg97.84 +- 0.1097.86 +- 0.0897.87 +- 0.10600100
LG-FedAvg97.85 +- 0.0997.88 +- 0.0997.89 +- 0.09700100
LG-FedAvg97.91 +- 0.1097.93 +- 0.0797.93 +- 0.07800100
FedAvg97.93 +- 0.0897.93 +- 0.0897.93 +- 0.088150
\n", + "
" + ], + "text/plain": [ + " Local Test New Test (avg) New Test (ens) FedAvg Rounds \\\n", + "Run \n", + "LocalOnly 88.03 +- 0.37 86.24 +- 0.87 91.15 +- 0.27 0 \n", + "LG-FedAvg 97.45 +- 0.12 97.50 +- 0.10 97.49 +- 0.09 300 \n", + "LG-FedAvg 97.59 +- 0.08 97.61 +- 0.08 97.62 +- 0.08 400 \n", + "LG-FedAvg 97.78 +- 0.13 97.82 +- 0.14 97.82 +- 0.13 500 \n", + "LG-FedAvg 97.84 +- 0.10 97.86 +- 0.08 97.87 +- 0.10 600 \n", + "LG-FedAvg 97.85 +- 0.09 97.88 +- 0.09 97.89 +- 0.09 700 \n", + "LG-FedAvg 97.91 +- 0.10 97.93 +- 0.07 97.93 +- 0.07 800 \n", + "FedAvg 97.93 +- 0.08 97.93 +- 0.08 97.93 +- 0.08 815 \n", + "\n", + " LG Rounds \n", + "Run \n", + "LocalOnly 0 \n", + "LG-FedAvg 100 \n", + "LG-FedAvg 100 \n", + "LG-FedAvg 100 \n", + "LG-FedAvg 100 \n", + "LG-FedAvg 100 \n", + "LG-FedAvg 100 \n", + "FedAvg 0 " + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pd.DataFrame(results, columns=columns).set_index(\"Run\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/main_fed.py b/main_fed.py index 428da6a..013a3b8 100644 --- a/main_fed.py +++ b/main_fed.py @@ -3,7 +3,9 @@ # Python version: 3.6 import copy +import pickle import numpy as np +import pandas as pd import torch from utils.options import args_parser @@ -19,21 +21,22 @@ args = args_parser() args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu') - base_dir = './save/{}/{}_iid{}_num{}_C{}_le{}/{}/'.format( - args.dataset, args.model, args.iid, args.num_users, args.frac, args.local_ep, args.results_save) + base_dir = './save/{}/{}_iid{}_num{}_C{}_le{}/shard{}/{}/'.format( + args.dataset, args.model, args.iid, args.num_users, args.frac, args.local_ep, args.shard_per_user, args.results_save) if not os.path.exists(os.path.join(base_dir, 'fed')): os.makedirs(os.path.join(base_dir, 'fed'), exist_ok=True) - dataset_train, dataset_test, dict_users_train, _, rand_set_all = get_data(args) - rand_save_path = os.path.join(base_dir, 'randset.npy') - np.save(rand_save_path, rand_set_all) + dataset_train, dataset_test, dict_users_train, dict_users_test = get_data(args) + dict_save_path = os.path.join(base_dir, 'dict_users.pkl') + with open(dict_save_path, 'wb') as handle: + pickle.dump((dict_users_train, dict_users_test), handle) # build model net_glob = get_model(args) net_glob.train() # training - results_save_path = os.path.join(base_dir, 'fed/results.npy') + results_save_path = os.path.join(base_dir, 'fed/results.csv') loss_train = [] net_best = None @@ -83,15 +86,25 @@ print('Round {:3d}, Average loss {:.3f}, Test loss {:.3f}, Test accuracy: {:.2f}'.format( iter, loss_avg, loss_test, acc_test)) - results.append(np.array([iter, loss_avg, loss_test, acc_test])) - final_results = np.array(results) - np.save(results_save_path, final_results) - model_save_path = os.path.join(base_dir, 'fed/model_{}.pt'.format(iter)) if best_acc is None or acc_test > best_acc: + net_best = copy.deepcopy(net_glob) best_acc = acc_test best_epoch = iter - if iter > args.start_saving: - torch.save(net_glob.state_dict(), model_save_path) + + # if (iter + 1) > 
args.start_saving: + # model_save_path = os.path.join(base_dir, 'fed/model_{}.pt'.format(iter + 1)) + # torch.save(net_glob.state_dict(), model_save_path) + + results.append(np.array([iter, loss_avg, loss_test, acc_test, best_acc])) + final_results = np.array(results) + final_results = pd.DataFrame(final_results, columns=['epoch', 'loss_avg', 'loss_test', 'acc_test', 'best_acc']) + final_results.to_csv(results_save_path, index=False) + + if (iter + 1) % 50 == 0: + best_save_path = os.path.join(base_dir, 'fed/best_{}.pt'.format(iter + 1)) + model_save_path = os.path.join(base_dir, 'fed/model_{}.pt'.format(iter + 1)) + torch.save(net_best.state_dict(), best_save_path) + torch.save(net_glob.state_dict(), model_save_path) print('Best model, iter: {}, acc: {}'.format(best_epoch, best_acc)) \ No newline at end of file diff --git a/main_lg.py b/main_lg.py index be975b7..ae704fc 100644 --- a/main_lg.py +++ b/main_lg.py @@ -4,15 +4,16 @@ import copy import os +import pickle import itertools +import pandas as pd import numpy as np import torch -from torch import nn from utils.options import args_parser from utils.train_utils import get_data, get_model from models.Update import LocalUpdate -from models.test import test_img, test_img_local +from models.test import test_img_local_all, test_img_avg_all, test_img_ensemble_all import pdb @@ -21,26 +22,24 @@ args = args_parser() args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu') - base_dir = './save/{}/{}_iid{}_num{}_C{}_le{}/{}/'.format( - args.dataset, args.model, args.iid, args.num_users, args.frac, args.local_ep, args.results_save) - if not os.path.exists(os.path.join(base_dir, 'lg')): - os.makedirs(os.path.join(base_dir, 'lg'), exist_ok=True) + base_dir = './save/{}/{}_iid{}_num{}_C{}_le{}/shard{}/{}/'.format( + args.dataset, args.model, args.iid, args.num_users, args.frac, args.local_ep, args.shard_per_user, args.results_save) - rand_set_all = [] - if len(args.load_fed) > 0: - rand_save_path = os.path.join(base_dir, 'randset.npy') - rand_set_all= np.load(rand_save_path) + assert(len(args.load_fed) > 0) + base_save_dir = os.path.join(base_dir, 'lg2/{}'.format(args.load_fed)) + if not os.path.exists(base_save_dir): + os.makedirs(base_save_dir, exist_ok=True) - dataset_train, dataset_test, dict_users_train, dict_users_test, rand_set_all = get_data(args, rand_set_all=rand_set_all) - rand_save_path = os.path.join(base_dir, 'randset.npy') - np.save(rand_save_path, rand_set_all) + dataset_train, dataset_test, dict_users_train, dict_users_test = get_data(args) + dict_save_path = os.path.join(base_dir, 'dict_users.pkl') + with open(dict_save_path, 'rb') as handle: + dict_users_train, dict_users_test = pickle.load(handle) # build model net_glob = get_model(args) net_glob.train() - if len(args.load_fed) > 0: - fed_model_path = os.path.join(base_dir, 'fed/{}'.format(args.load_fed)) + fed_model_path = os.path.join(base_dir, 'fed/{}'.format(args.load_fed)) net_glob.load_state_dict(torch.load(fed_model_path)) total_num_layers = len(net_glob.weight_keys) @@ -56,96 +55,36 @@ percentage_param = 100 * float(num_param_glob) / num_param_local print('# Params: {} (local), {} (global); Percentage {:.2f} ({}/{})'.format( num_param_local, num_param_glob, percentage_param, num_param_glob, num_param_local)) + # generate list of local models for each user net_local_list = [] - for user_ix in range(args.num_users): + for user in range(args.num_users): net_local_list.append(copy.deepcopy(net_glob)) - criterion = 
nn.CrossEntropyLoss() - - - def test_img_ensemble_all(): - probs_all = [] - preds_all = [] - for idx in range(args.num_users): - net_local = net_local_list[idx] - net_local.eval() - # _, _, probs = test_img(net_local, dataset_test, args, return_probs=True, user_idx=idx) - acc, loss, probs = test_img(net_local, dataset_test, args, return_probs=True, user_idx=idx) - # print('Local model: {}, loss: {}, acc: {}'.format(idx, loss, acc)) - probs_all.append(probs.detach()) - - preds = probs.data.max(1, keepdim=True)[1].cpu().numpy().reshape(-1) - preds_all.append(preds) - - labels = np.array(dataset_test.targets) - preds_probs = torch.mean(torch.stack(probs_all), dim=0) - - # ensemble metrics - preds_avg = preds_probs.data.max(1, keepdim=True)[1].cpu().numpy().reshape(-1) - loss_test = criterion(preds_probs, torch.tensor(labels).to(args.device)).item() - acc_test = (preds_avg == labels).mean() * 100 - - return loss_test, acc_test - - def test_img_local_all(): - acc_test_local = 0 - loss_test_local = 0 - for idx in range(args.num_users): - net_local = net_local_list[idx] - net_local.eval() - a, b = test_img_local(net_local, dataset_test, args, user_idx=idx, idxs=dict_users_test[idx]) - - acc_test_local += a - loss_test_local += b - acc_test_local /= args.num_users - loss_test_local /= args.num_users - - return acc_test_local, loss_test_local - - def test_img_avg_all(): - net_glob_temp = copy.deepcopy(net_glob) - w_keys_epoch = net_glob.state_dict().keys() - w_glob_temp = {} - for idx in range(args.num_users): - net_local = net_local_list[idx] - w_local = net_local.state_dict() - - if len(w_glob_temp) == 0: - w_glob_temp = copy.deepcopy(w_local) - else: - for k in w_keys_epoch: - w_glob_temp[k] += w_local[k] - - for k in w_keys_epoch: - w_glob_temp[k] = torch.div(w_glob_temp[k], args.num_users) - net_glob_temp.load_state_dict(w_glob_temp) - acc_test_avg, loss_test_avg = test_img(net_glob_temp, dataset_test, args) - - return acc_test_avg, loss_test_avg + acc_test_local, loss_test_local = test_img_local_all(net_local_list, args, dataset_test, dict_users_test) + acc_test_avg, loss_test_avg = test_img_avg_all(net_glob, net_local_list, args, dataset_test) # training - results_save_path = os.path.join(base_dir, 'lg/results.npy') + results_save_path = os.path.join(base_save_dir, 'results.csv') loss_train = [] net_best = None - best_loss = None - best_acc = None - best_epoch = None + best_acc = np.ones(args.num_users) * acc_test_local + best_net_list = copy.deepcopy(net_local_list) lr = args.lr results = [] + results.append(np.array([-1, acc_test_local, acc_test_avg, best_acc.mean(), None, None])) + print('Round {:3d}, Avg Loss {:.3f}, Loss (local): {:.3f}, Acc (local): {:.2f}, Loss (avg): {:.3}, Acc (avg): {:.2f}'.format( + -1, -1, loss_test_local, acc_test_local, loss_test_avg, acc_test_avg)) for iter in range(args.epochs): w_glob = {} loss_locals = [] m = max(int(args.frac * args.num_users), 1) idxs_users = np.random.choice(range(args.num_users), m, replace=False) - # w_keys_epoch = net_glob.state_dict().keys() if (iter + 1) % 25 == 0 else w_glob_keys w_keys_epoch = w_glob_keys - if args.verbose: - print("Round {}: lr: {:.6f}, {}".format(iter, lr, idxs_users)) for idx in idxs_users: local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users_train[idx]) net_local = net_local_list[idx] @@ -163,9 +102,6 @@ def test_img_avg_all(): for k in w_keys_epoch: w_glob[k] += w_local[k] - if (iter+1) % int(args.num_users * args.frac): - lr *= args.lr_decay - # get weighted average for global weights for k 
in w_keys_epoch: w_glob[k] = torch.div(w_glob[k], m) @@ -186,29 +122,31 @@ def test_img_avg_all(): loss_train.append(loss_avg) # eval - acc_test_local, loss_test_local = test_img_local_all() - acc_test_avg, loss_test_avg = test_img_avg_all() + acc_test_local, loss_test_local = test_img_local_all(net_local_list, args, dataset_test, dict_users_test, return_all=True) + + for user in range(args.num_users): + if acc_test_local[user] > best_acc[user]: + best_acc[user] = acc_test_local[user] + best_net_list[user] = copy.deepcopy(net_local_list[user]) + + model_save_path = os.path.join(base_save_dir, 'model_user{}.pt'.format(user)) + torch.save(best_net_list[user].state_dict(), model_save_path) + acc_test_local, loss_test_local = test_img_local_all(best_net_list, args, dataset_test, dict_users_test) + acc_test_avg, loss_test_avg = test_img_avg_all(net_glob, best_net_list, args, dataset_test) print('Round {:3d}, Avg Loss {:.3f}, Loss (local): {:.3f}, Acc (local): {:.2f}, Loss (avg): {:.3}, Acc (avg): {:.2f}'.format( iter, loss_avg, loss_test_local, acc_test_local, loss_test_avg, acc_test_avg)) - results.append(np.array([iter, loss_avg, loss_test_local, acc_test_local, loss_test_avg, acc_test_avg])) + results.append(np.array([iter, acc_test_local, acc_test_avg, best_acc.mean(), None, None])) final_results = np.array(results) - np.save(results_save_path, final_results) - - if best_acc is None or acc_test_local > best_acc: - best_acc = acc_test_local - best_epoch = iter - - for user in range(args.num_users): - model_save_path = os.path.join(base_dir, 'lg/model_user{}.pt'.format(user)) - torch.save(net_local_list[user].state_dict(), model_save_path) - - for user in range(args.num_users): - model_save_path = os.path.join(base_dir, 'lg/model_user{}.pt'.format(user)) + final_results = pd.DataFrame(final_results, columns=['epoch', 'acc_test_local', 'acc_test_avg', 'best_acc_local', 'acc_test_ens_avg', 'acc_test_ens_maj']) + final_results.to_csv(results_save_path, index=False) - net_local = net_local_list[idx] - net_local.load_state_dict(torch.load(model_save_path)) - loss_test, acc_test = test_img_ensemble_all() + acc_test_ens_avg, loss_test, acc_test_ens_maj = test_img_ensemble_all(best_net_list, args, dataset_test) + print('Best model, acc (local): {}, acc (ens,avg): {}, acc (ens,maj): {}'.format(best_acc.mean(), acc_test_ens_avg, acc_test_ens_maj)) - print('Best model, iter: {}, acc: {}, acc (ens): {}'.format(best_epoch, best_acc, acc_test)) \ No newline at end of file + results.append(np.array(['Final', None, None, best_acc.mean(), acc_test_ens_avg, acc_test_ens_maj])) + final_results = np.array(results) + final_results = pd.DataFrame(final_results, + columns=['epoch', 'acc_test_local', 'acc_test_avg', 'best_acc_local', 'acc_test_ens_avg', 'acc_test_ens_maj']) + final_results.to_csv(results_save_path, index=False) \ No newline at end of file diff --git a/main_lg_backup.py b/main_lg_backup.py new file mode 100644 index 0000000..c1a0583 --- /dev/null +++ b/main_lg_backup.py @@ -0,0 +1,156 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Python version: 3.6 + +import copy +import os +import pickle +import itertools +import pandas as pd +import numpy as np +import torch + +from utils.options import args_parser +from utils.train_utils import get_data, get_model +from models.Update import LocalUpdate +from models.test import test_img_local_all, test_img_avg_all, test_img_ensemble_all + +import pdb + +if __name__ == '__main__': + # parse args + args = args_parser() + args.device = 
torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu') + + base_dir = './save/{}/{}_iid{}_num{}_C{}_le{}/shard{}/{}/'.format( + args.dataset, args.model, args.iid, args.num_users, args.frac, args.local_ep, args.shard_per_user, args.results_save) + + assert(len(args.load_fed) > 0) + base_save_dir = os.path.join(base_dir, 'lg/{}'.format(args.load_fed)) + if not os.path.exists(base_save_dir): + os.makedirs(base_save_dir, exist_ok=True) + + dataset_train, dataset_test, dict_users_train, dict_users_test = get_data(args) + dict_save_path = os.path.join(base_dir, 'dict_users.pkl') + with open(dict_save_path, 'rb') as handle: + dict_users_train, dict_users_test = pickle.load(handle) + + # build model + net_glob = get_model(args) + net_glob.train() + + fed_model_path = os.path.join(base_dir, 'fed/{}'.format(args.load_fed)) + net_glob.load_state_dict(torch.load(fed_model_path)) + + total_num_layers = len(net_glob.weight_keys) + w_glob_keys = net_glob.weight_keys[total_num_layers - args.num_layers_keep:] + w_glob_keys = list(itertools.chain.from_iterable(w_glob_keys)) + + num_param_glob = 0 + num_param_local = 0 + for key in net_glob.state_dict().keys(): + num_param_local += net_glob.state_dict()[key].numel() + if key in w_glob_keys: + num_param_glob += net_glob.state_dict()[key].numel() + percentage_param = 100 * float(num_param_glob) / num_param_local + print('# Params: {} (local), {} (global); Percentage {:.2f} ({}/{})'.format( + num_param_local, num_param_glob, percentage_param, num_param_glob, num_param_local)) + # generate list of local models for each user + net_local_list = [] + for user_ix in range(args.num_users): + net_local_list.append(copy.deepcopy(net_glob)) + + # training + results_save_path = os.path.join(base_save_dir, 'results.csv') + + loss_train = [] + net_best = None + best_loss = None + best_acc = None + best_epoch = None + + lr = args.lr + results = [] + + for iter in range(args.epochs): + w_glob = {} + loss_locals = [] + m = max(int(args.frac * args.num_users), 1) + idxs_users = np.random.choice(range(args.num_users), m, replace=False) + # w_keys_epoch = net_glob.state_dict().keys() if (iter + 1) % 25 == 0 else w_glob_keys + w_keys_epoch = w_glob_keys + + if args.verbose: + print("Round {}: lr: {:.6f}, {}".format(iter, lr, idxs_users)) + for idx in idxs_users: + local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users_train[idx]) + net_local = net_local_list[idx] + + w_local, loss = local.train(net=net_local.to(args.device), lr=lr) + loss_locals.append(copy.deepcopy(loss)) + + modules_glob = set([x.split('.')[0] for x in w_keys_epoch]) + modules_all = net_local.__dict__['_modules'] + + # sum up weights + if len(w_glob) == 0: + w_glob = copy.deepcopy(net_glob.state_dict()) + else: + for k in w_keys_epoch: + w_glob[k] += w_local[k] + + if (iter+1) % int(args.num_users * args.frac): + lr *= args.lr_decay + + # get weighted average for global weights + for k in w_keys_epoch: + w_glob[k] = torch.div(w_glob[k], m) + + # copy weight to the global model (not really necessary) + net_glob.load_state_dict(w_glob) + + # copy weights to each local model + for idx in range(args.num_users): + net_local = net_local_list[idx] + w_local = net_local.state_dict() + for k in w_keys_epoch: + w_local[k] = w_glob[k] + + net_local.load_state_dict(w_local) + + loss_avg = sum(loss_locals) / len(loss_locals) + loss_train.append(loss_avg) + + # eval + acc_test_local, loss_test_local = test_img_local_all(net_local_list, args, dataset_test, 
dict_users_test) + acc_test_avg, loss_test_avg = test_img_avg_all(net_glob, net_local_list, args, dataset_test) + print('Round {:3d}, Avg Loss {:.3f}, Loss (local): {:.3f}, Acc (local): {:.2f}, Loss (avg): {:.3}, Acc (avg): {:.2f}'.format( + iter, loss_avg, loss_test_local, acc_test_local, loss_test_avg, acc_test_avg)) + + if best_acc is None or acc_test_local > best_acc: + best_acc = acc_test_local + best_epoch = iter + + for user in range(args.num_users): + model_save_path = os.path.join(base_save_dir, 'model_user{}.pt'.format(user)) + torch.save(net_local_list[user].state_dict(), model_save_path) + + results.append(np.array([iter, acc_test_local, acc_test_avg, best_acc, None, None])) + final_results = np.array(results) + final_results = pd.DataFrame(final_results, columns=['epoch', 'acc_test_local', 'acc_test_avg', 'best_acc_local', 'acc_test_ens_avg', 'acc_test_ens_maj']) + final_results.to_csv(results_save_path, index=False) + + for user in range(args.num_users): + model_save_path = os.path.join(base_save_dir, 'model_user{}.pt'.format(user)) + + net_local = net_local_list[user] + net_local.load_state_dict(torch.load(model_save_path)) + acc_test_ens_avg, loss_test, acc_test_ens_maj = test_img_ensemble_all(net_local_list, args, dataset_test) + + print('Best model, iter: {}, acc (local): {}, acc (ens,avg): {}, acc (ens,maj): {}'.format(best_epoch, best_acc, acc_test_ens_avg, acc_test_ens_maj)) + + results.append(np.array(['Final', None, None, best_acc, acc_test_ens_avg, acc_test_ens_maj])) + final_results = np.array(results) + final_results = pd.DataFrame(final_results, + columns=['epoch', 'acc_test_local', 'acc_test_avg', 'best_acc_local', 'acc_test_ens_avg', 'acc_test_ens_maj']) + final_results.to_csv(results_save_path, index=False) diff --git a/main_local.py b/main_local.py new file mode 100644 index 0000000..792414b --- /dev/null +++ b/main_local.py @@ -0,0 +1,96 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Python version: 3.6 + +import copy +import os +import pickle +import pandas as pd +import numpy as np +import torch +from torch import nn +from torch.utils.data import DataLoader + +from utils.options import args_parser +from utils.train_utils import get_data, get_model +from models.Update import DatasetSplit +from models.test import test_img_local, test_img_local_all, test_img_avg_all, test_img_ensemble_all + +import pdb + +if __name__ == '__main__': + # parse args + args = args_parser() + args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu') + + base_dir = './save/{}/{}_iid{}_num{}_C{}_le{}/shard{}/{}/'.format( + args.dataset, args.model, args.iid, args.num_users, args.frac, args.local_ep, args.shard_per_user, args.results_save) + if not os.path.exists(os.path.join(base_dir, 'local')): + os.makedirs(os.path.join(base_dir, 'local'), exist_ok=True) + + dataset_train, dataset_test, dict_users_train, dict_users_test = get_data(args) + dict_save_path = os.path.join(base_dir, 'dict_users.pkl') + with open(dict_save_path, 'rb') as handle: + dict_users_train, dict_users_test = pickle.load(handle) + + # build model + net_glob = get_model(args) + net_glob.train() + + net_local_list = [] + for user_ix in range(args.num_users): + net_local_list.append(copy.deepcopy(net_glob)) + + # training + results_save_path = os.path.join(base_dir, 'local/results.csv') + + loss_train = [] + net_best = None + best_loss = None + best_acc = None + best_epoch = None + + lr = args.lr + results = [] + + criterion = nn.CrossEntropyLoss() + + 
for user, net_local in enumerate(net_local_list): + model_save_path = os.path.join(base_dir, 'local/model_user{}.pt'.format(user)) + net_best = None + best_acc = None + + ldr_train = DataLoader(DatasetSplit(dataset_train, dict_users_train[user]), batch_size=args.local_bs, shuffle=True) + optimizer = torch.optim.SGD(net_local.parameters(), lr=lr, momentum=0.5) + for iter in range(args.epochs): + for batch_idx, (images, labels) in enumerate(ldr_train): + images, labels = images.to(args.device), labels.to(args.device) + net_local.zero_grad() + log_probs = net_local(images) + + loss = criterion(log_probs, labels) + loss.backward() + optimizer.step() + + acc_test, loss_test = test_img_local(net_local, dataset_test, args, user_idx=user, idxs=dict_users_test[user]) + if best_acc is None or acc_test > best_acc: + best_acc = acc_test + net_best = copy.deepcopy(net_local) + # torch.save(net_local_list[user].state_dict(), model_save_path) + + print('User {}, Epoch {}, Acc {:.2f}'.format(user, iter, acc_test)) + + if iter > 50 and acc_test >= 99: + break + + net_local_list[user] = net_best + + acc_test_local, loss_test_local = test_img_local_all(net_local_list, args, dataset_test, dict_users_test) + acc_test_avg, loss_test_avg = test_img_avg_all(net_glob, net_local_list, args, dataset_test) + acc_test_ens_avg, loss_test, acc_test_ens_maj = test_img_ensemble_all(net_local_list, args, dataset_test) + + print('Final: acc: {:.2f}, acc (avg): {:.2f}, acc (ens,avg): {:.2f}, acc (ens,maj): {:.2f}'.format(acc_test_local, acc_test_avg, acc_test_ens_avg, acc_test_ens_maj)) + + final_results = np.array([[acc_test_local, acc_test_avg, acc_test_ens_avg, acc_test_ens_maj]]) + final_results = pd.DataFrame(final_results, columns=['acc_test_local', 'acc_test_avg', 'acc_test_ens_avg', 'acc_test_ens_maj']) + final_results.to_csv(results_save_path, index=False) diff --git a/main_mtl.py b/main_mtl.py index 26f521c..c3ac4dc 100644 --- a/main_mtl.py +++ b/main_mtl.py @@ -7,19 +7,21 @@ # import matplotlib.pyplot as plt import copy import os +import pickle import itertools import numpy as np +import pandas as pd +from tqdm import tqdm from scipy.stats import mode from torchvision import datasets, transforms, models import torch from torch import nn -from utils.sampling import mnist_iid, mnist_noniid, cifar10_iid, cifar10_noniid +from utils.train_utils import get_model, get_data from utils.options import args_parser from models.Update import LocalUpdateMTL -from models.Nets import MLP, CNNMnist, CNNCifar, ResnetCifar -from models.Fed import FedAvg -from models.test import test_img, test_img_local +from models.test import test_img, test_img_local, test_img_local_all, test_img_avg_all, test_img_ensemble_all + import pdb @@ -28,93 +30,24 @@ args = args_parser() args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu') - trans_mnist = transforms.Compose([transforms.ToTensor(), - transforms.Normalize((0.1307,), (0.3081,))]) - - if args.model == 'resnet': - trans_cifar_train = transforms.Compose([transforms.RandomCrop(32, padding=4), - transforms.RandomHorizontalFlip(), - transforms.Resize([256,256]), - transforms.CenterCrop(224), - transforms.ToTensor(), - transforms.Normalize(mean=[0.485, 0.456, 0.406], - std=[0.229, 0.224, 0.225])]) - trans_cifar_val = transforms.Compose([transforms.Resize([256,256]), - transforms.CenterCrop(224), - transforms.ToTensor(), - transforms.Normalize(mean=[0.485, 0.456, 0.406], - std=[0.229, 0.224, 0.225])]) - else: - trans_cifar_train = 
transforms.Compose([transforms.RandomCrop(32, padding=4), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), - transforms.Normalize(mean=[0.485, 0.456, 0.406], - std=[0.229, 0.224, 0.225])]) - trans_cifar_val = transforms.Compose([transforms.ToTensor(), - transforms.Normalize(mean=[0.485, 0.456, 0.406], - std=[0.229, 0.224, 0.225])]) + base_dir = './save/{}/{}_iid{}_num{}_C{}_le{}/shard{}/{}/'.format( + args.dataset, args.model, args.iid, args.num_users, args.frac, args.local_ep, args.shard_per_user, args.results_save) - # load dataset and split users - if args.dataset == 'mnist': - dataset_train = datasets.MNIST('data/mnist/', train=True, download=True, transform=trans_mnist) - dataset_test = datasets.MNIST('data/mnist/', train=False, download=True, transform=trans_mnist) - # sample users - if args.iid: - dict_users_train = mnist_iid(dataset_train, args.num_users) - dict_users_test = mnist_iid(dataset_test, args.num_users) - else: - dict_users_train, rand_set_all = mnist_noniid(dataset_train, args.num_users, num_shards=200, num_imgs=300, train=True) - dict_users_test, _ = mnist_noniid(dataset_test, args.num_users, num_shards=200, num_imgs=50, train=False, rand_set_all=rand_set_all) - save_path = './save/{}/randset_lg_{}_iid{}_num{}_C{}.pt'.format( - args.results_save, args.dataset, args.iid, args.num_users, args.frac) - np.save(save_path, rand_set_all) + base_save_dir = os.path.join(base_dir, 'mtl') + if not os.path.exists(base_save_dir): + os.makedirs(base_save_dir, exist_ok=True) - elif args.dataset == 'cifar10': - dataset_train = datasets.CIFAR10('data/cifar10', train=True, download=True, transform=trans_cifar_train) - dataset_test = datasets.CIFAR10('data/cifar10', train=False, download=True, transform=trans_cifar_val) - if args.iid: - dict_users_train = cifar10_iid(dataset_train, args.num_users) - dict_users_test = cifar10_iid(dataset_test, args.num_users) - else: - dict_users_train, rand_set_all = cifar10_noniid(dataset_train, args.num_users, num_shards=200, num_imgs=250, train=True) - dict_users_test, _ = cifar10_noniid(dataset_test, args.num_users, num_shards=200, num_imgs=50, train=False, rand_set_all=rand_set_all) - - elif args.dataset == 'cifar100': - dataset_train = datasets.CIFAR100('data/cifar100', train=True, download=True, transform=trans_cifar_train) - dataset_test = datasets.CIFAR100('data/cifar100', train=False, download=True, transform=trans_cifar_val) - if args.iid: - dict_users_train = cifar10_iid(dataset_train, args.num_users) - dict_users_test = cifar10_iid(dataset_test, args.num_users) - else: - dict_users_train, rand_set_all = cifar10_noniid(dataset_train, args.num_users, num_shards=1000, num_imgs=50, train=True) - dict_users_test, _ = cifar10_noniid(dataset_test, args.num_users, num_shards=1000, num_imgs=10, train=False, rand_set_all=rand_set_all) - else: - exit('Error: unrecognized dataset') - img_size = dataset_train[0][0].shape + dataset_train, dataset_test, dict_users_train, dict_users_test = get_data(args) + dict_save_path = os.path.join(base_dir, 'dict_users.pkl') + with open(dict_save_path, 'rb') as handle: + dict_users_train, dict_users_test = pickle.load(handle) # build model - if args.model == 'cnn' and args.dataset in ['cifar10', 'cifar100']: - net_glob = CNNCifar(args=args).to(args.device) - elif args.model == 'cnn' and args.dataset == 'mnist': - net_glob = CNNMnist(args=args).to(args.device) - elif args.model == 'resnet' and args.dataset in ['cifar10', 'cifar100']: - net_glob = ResnetCifar(args=args).to(args.device) - elif args.model == 
'mlp': - len_in = 1 - for x in img_size: - len_in *= x - net_glob = MLP(dim_in=len_in, dim_hidden=64, dim_out=args.num_classes).to(args.device) - else: - exit('Error: unrecognized model') + net_glob = get_model(args) + net_glob.train() print(net_glob) net_glob.train() - if args.load_fed: - fed_model_path = './save/keep/fed_{}_{}_iid{}_num{}_C{}_le{}_gn{}.npy'.format( - args.dataset, args.model, args.iid, args.num_users, args.frac, args.local_ep, args.grad_norm) - if len(args.load_fed_name) > 0: - fed_model_path = './save/keep/{}'.format(args.load_fed_name) - net_glob.load_state_dict(torch.load(fed_model_path)) total_num_layers = len(net_glob.weight_keys) w_glob_keys = net_glob.weight_keys[total_num_layers - args.num_layers_keep:] @@ -129,6 +62,7 @@ percentage_param = 100 * float(num_param_glob) / num_param_local print('# Params: {} (local), {} (global); Percentage {:.2f} ({}/{})'.format( num_param_local, num_param_glob, percentage_param, num_param_glob, num_param_local)) + # generate list of local models for each user net_local_list = [] for user_ix in range(args.num_users): @@ -136,104 +70,13 @@ criterion = nn.CrossEntropyLoss() - - def test_img_ensemble_all(): - probs_all = [] - preds_all = [] - for idx in range(args.num_users): - net_local = net_local_list[idx] - net_local.eval() - # _, _, probs = test_img(net_local, dataset_test, args, return_probs=True, user_idx=idx) - acc, loss, probs = test_img(net_local, dataset_test, args, return_probs=True, user_idx=idx) - # print('Local model: {}, loss: {}, acc: {}'.format(idx, loss, acc)) - probs_all.append(probs.detach()) - - preds = probs.data.max(1, keepdim=True)[1].cpu().numpy().reshape(-1) - preds_all.append(preds) - - labels = np.array(dataset_test.targets) - preds_probs = torch.mean(torch.stack(probs_all), dim=0) - - # ensemble metrics - preds_avg = preds_probs.data.max(1, keepdim=True)[1].cpu().numpy().reshape(-1) - loss_test = criterion(preds_probs, torch.tensor(labels).to(args.device)).item() - acc_test = (preds_avg == labels).mean() * 100 - - return loss_test, acc_test - - def test_img_local_all(): - acc_test_local = 0 - loss_test_local = 0 - for idx in range(args.num_users): - net_local = net_local_list[idx] - net_local.eval() - a, b = test_img_local(net_local, dataset_test, args, user_idx=idx, idxs=dict_users_test[idx]) - - acc_test_local += a - loss_test_local += b - acc_test_local /= args.num_users - loss_test_local /= args.num_users - - return acc_test_local, loss_test_local - - def test_img_avg_all(): - net_glob_temp = copy.deepcopy(net_glob) - w_keys_epoch = net_glob.state_dict().keys() - w_glob_temp = {} - for idx in range(args.num_users): - net_local = net_local_list[idx] - w_local = net_local.state_dict() - - if len(w_glob_temp) == 0: - w_glob_temp = copy.deepcopy(w_local) - else: - for k in w_keys_epoch: - w_glob_temp[k] += w_local[k] - - for k in w_keys_epoch: - w_glob_temp[k] = torch.div(w_glob_temp[k], args.num_users) - net_glob_temp.load_state_dict(w_glob_temp) - acc_test_avg, loss_test_avg = test_img(net_glob_temp, dataset_test, args) - - return acc_test_avg, loss_test_avg - - if args.local_ep_pretrain > 0: - # pretrain each local model - pretrain_save_path = 'pretrain/{}/{}_{}/user_{}/ep_{}/'.format(args.model, args.dataset, - 'iid' if args.iid else 'noniid', args.num_users, - args.local_ep_pretrain) - if not os.path.exists(pretrain_save_path): - os.makedirs(pretrain_save_path) - - print("\nPretraining local models...") - for idx in range(args.num_users): - net_local = net_local_list[idx] - net_local_path = 
os.path.join(pretrain_save_path, '{}.pt'.format(idx)) - if os.path.exists(net_local_path): # check if we have a saved model - net_local.load_state_dict(torch.load(net_local_path)) - else: - local = LocalUpdateMTL(args=args, dataset=dataset_train, idxs=dict_users_train[idx], pretrain=True) - w_local, loss = local.train(net=net_local.to(args.device)) - print('Local model {}, Train Epoch Loss: {:.4f}'.format(idx, loss)) - torch.save(net_local.state_dict(), net_local_path) - - print("Getting initial loss and acc...") - acc_test_local, loss_test_local = test_img_local_all() - acc_test_avg, loss_test_avg = test_img_avg_all() - loss_test, acc_test = test_img_ensemble_all() - - print('Initial Ensemble: Loss (local): {:.3f}, Acc (local): {:.2f}, Loss (avg): {:.3}, Acc (avg): {:.2f}, Loss (ens) {:.3f}, Acc: (ens) {:.2f}, '.format( - loss_test_local, acc_test_local, loss_test_avg, acc_test_avg, loss_test, acc_test)) - # training + results_save_path = os.path.join(base_save_dir, 'results.csv') + loss_train = [] - cv_loss, cv_acc = [], [] - val_loss_pre, counter = 0, 0 net_best = None - best_loss = None - best_acc = None - best_epoch = None - val_acc_list, net_list = [], [] + best_acc = np.ones(args.num_users) * -1 + best_net_list = copy.deepcopy(net_local_list) lr = args.lr results = [] @@ -257,82 +100,49 @@ def test_img_avg_all(): idxs_users = np.random.choice(range(args.num_users), m, replace=False) W = torch.zeros((d, m)).cuda() - for i in range(m): - W_local = [net_local_list[i].state_dict()[key].flatten() for key in w_glob_keys] + # for idx, user in enumerate(idxs_users): + for idx, user in enumerate(range(10)): + W_local = [net_local_list[user].state_dict()[key].flatten() for key in w_glob_keys] W_local = torch.cat(W_local) - W[:, i] = W_local + W[:, idx] = W_local - if args.verbose: - print("Round {}: lr: {:.6f}, {}".format(iter, lr, idxs_users)) - for i, idx in enumerate(idxs_users): - local = LocalUpdateMTL(args=args, dataset=dataset_train, idxs=dict_users_train[idx]) - net_local = net_local_list[idx] + for idx, user in enumerate(idxs_users): + local = LocalUpdateMTL(args=args, dataset=dataset_train, idxs=dict_users_train[user]) + net_local = net_local_list[user] w_local, loss = local.train(net=net_local.to(args.device), lr=lr, - omega=omega, W_glob=W.clone(), i=i, w_glob_keys=w_glob_keys) + omega=omega, W_glob=W.clone(), idx=idx, w_glob_keys=w_glob_keys) loss_locals.append(copy.deepcopy(loss)) - if (iter+1) % int(args.num_users * args.frac): - lr *= args.lr_decay - loss_avg = sum(loss_locals) / len(loss_locals) loss_train.append(loss_avg) # eval - acc_test_local, loss_test_local = test_img_local_all() - acc_test_avg, loss_test_avg = test_img_avg_all() + acc_test_local, loss_test_local = test_img_local_all(net_local_list, args, dataset_test, dict_users_test, return_all=True) - # if (iter + 1) % args.test_freq == 0: # this takes too much time, so we run it less frequently - if (iter + 1) > 2000: - loss_test, acc_test = test_img_ensemble_all() - print('Round {:3d}, Avg Loss {:.3f}, Loss (local): {:.3f}, Acc (local): {:.2f}, Loss (avg): {:.3}, Acc (avg): {:.2f}, Loss (ens) {:.3f}, Acc: (ens) {:.2f}, '.format( - iter, loss_avg, loss_test_local, acc_test_local, loss_test_avg, acc_test_avg, loss_test, acc_test)) - results.append(np.array([iter, loss_avg, loss_test_local, acc_test_local, loss_test_avg, acc_test_avg, loss_test, acc_test])) + for user in range(args.num_users): + if acc_test_local[user] > best_acc[user]: + best_acc[user] = acc_test_local[user] + best_net_list[user] = 
copy.deepcopy(net_local_list[user]) - else: - print('Round {:3d}, Avg Loss {:.3f}, Loss (local): {:.3f}, Acc (local): {:.2f}, Loss (avg): {:.3}, Acc (avg): {:.2f}'.format( - iter, loss_avg, loss_test_local, acc_test_local, loss_test_avg, acc_test_avg)) - results.append(np.array([iter, loss_avg, loss_test_local, acc_test_local, loss_test_avg, acc_test_avg, np.nan, np.nan])) + model_save_path = os.path.join(base_save_dir, 'model_user{}.pt'.format(user)) + torch.save(best_net_list[user].state_dict(), model_save_path) - final_results = np.array(results) - results_save_path = './log/lg_{}_{}_keep{}_iid{}_num{}_C{}_le{}_gn{}_pt{}_load{}_tfreq{}.npy'.format( - args.dataset, args.model, args.num_layers_keep, args.iid, args.num_users, args.frac, - args.local_ep, args.grad_norm, args.local_ep_pretrain, args.load_fed, args.test_freq) - np.save(results_save_path, final_results) - - if best_acc is None or acc_test_local > best_acc: - best_acc = acc_test_local - best_epoch = iter - - for user in range(args.num_users): - model_save_path = './save/{}/{}_{}_iid{}/'.format( - args.results_save, args.dataset, 0, args.iid) - if not os.path.exists(model_save_path): - os.makedirs(model_save_path) - - model_save_path = './save/{}/{}_{}_iid{}/user{}.pt'.format( - args.results_save, args.dataset, 0, args.iid, user) - torch.save(net_local_list[user].state_dict(), model_save_path) + acc_test_local, loss_test_local = test_img_local_all(best_net_list, args, dataset_test, dict_users_test) + acc_test_avg, loss_test_avg = test_img_avg_all(net_glob, best_net_list, args, dataset_test) + print('Round {:3d}, Avg Loss {:.3f}, Loss (local): {:.3f}, Acc (local): {:.2f}, Loss (avg): {:.3}, Acc (avg): {:.2f}'.format( + iter, loss_avg, loss_test_local, acc_test_local, loss_test_avg, acc_test_avg)) - for user in range(args.num_users): - model_save_path = './save/{}/{}_{}_iid{}/user{}.pt'.format( - args.results_save, args.dataset, 0, args.iid, user) - - net_local = net_local_list[idx] - net_local.load_state_dict(torch.load(model_save_path)) - loss_test, acc_test = test_img_ensemble_all() - - print('Best model, iter: {}, acc: {}, acc (ens): {}'.format(best_epoch, best_acc, acc_test)) + results.append(np.array([iter, acc_test_local, acc_test_avg, best_acc.mean(), None, None])) + final_results = np.array(results) + final_results = pd.DataFrame(final_results, columns=['epoch', 'acc_test_local', 'acc_test_avg', 'best_acc_local', 'acc_test_ens_avg', 'acc_test_ens_maj']) + final_results.to_csv(results_save_path, index=False) - # plot loss curve - # plt.figure() - # plt.plot(range(len(loss_train)), loss_train) - # plt.ylabel('train_loss') - # plt.savefig('./log/fed_{}_{}_{}_C{}_iid{}.png'.format(args.dataset, args.model, args.epochs, args.frac, args.iid)) + acc_test_ens_avg, loss_test, acc_test_ens_maj = test_img_ensemble_all(best_net_list, args, dataset_test) + print('Best model, acc (local): {}, acc (ens,avg): {}, acc (ens,maj): {}'.format(best_acc, acc_test_ens_avg, acc_test_ens_maj)) - # testing - net_glob.eval() - acc_train, loss_train = test_img(net_glob, dataset_train, args) - acc_test, loss_test = test_img(net_glob, dataset_test, args) - print("Training accuracy: {:.2f}".format(acc_train)) - print("Testing accuracy: {:.2f}".format(acc_test)) \ No newline at end of file + results.append(np.array(['Final', None, None, best_acc.mean(), acc_test_ens_avg, acc_test_ens_maj])) + final_results = np.array(results) + final_results = pd.DataFrame(final_results, + columns=['epoch', 'acc_test_local', 'acc_test_avg', 'best_acc_local', 
'acc_test_ens_avg', 'acc_test_ens_maj']) + final_results.to_csv(results_save_path, index=False) \ No newline at end of file diff --git a/models/Update.py b/models/Update.py index 449f454..a2b1f2a 100644 --- a/models/Update.py +++ b/models/Update.py @@ -5,7 +5,9 @@ import torch from torch import nn from torch.utils.data import DataLoader, Dataset +from tqdm import tqdm import math +import pdb class DatasetSplit(Dataset): def __init__(self, dataset, idxs): @@ -51,16 +53,6 @@ def train(self, net, idx=-1, lr=0.1): batch_loss.append(loss.item()) - if not self.pretrain and self.args.verbose and batch_idx % 300 == 0: - if idx < 0: - print('Update Epoch: {} [{}/{} ({:.0f}%)], Epoch Loss: {:.4f}, Batch Loss: {:.4f}'.format( - iter, batch_idx * len(images), len(self.ldr_train.dataset), 100. * batch_idx / len(self.ldr_train), - sum(batch_loss)/len(batch_loss), loss.item())) - else: - print('Local model {}, Update Epoch: {} [{}/{} ({:.0f}%)], Epoch Loss: {:.4f}, Batch Loss: {:.4f}'.format( - idx, iter, batch_idx * len(images), len(self.ldr_train.dataset), 100. * batch_idx / len(self.ldr_train), - sum(batch_loss)/len(batch_loss), loss.item())) - epoch_loss.append(sum(batch_loss)/len(batch_loss)) return net.state_dict(), sum(epoch_loss) / len(epoch_loss) @@ -74,7 +66,7 @@ def __init__(self, args, dataset=None, idxs=None, pretrain=False): self.ldr_train = DataLoader(DatasetSplit(dataset, idxs), batch_size=self.args.local_bs, shuffle=True) self.pretrain = pretrain - def train(self, net, idx=-1, lr=0.1, omega=None, W_glob=None, i=None, w_glob_keys=None): + def train(self, net, lr=0.1, omega=None, W_glob=None, idx=None, w_glob_keys=None): net.train() # train and update optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.5) @@ -98,12 +90,12 @@ def train(self, net, idx=-1, lr=0.1, omega=None, W_glob=None, i=None, w_glob_key W_local = [net.state_dict(keep_vars=True)[key].flatten() for key in w_glob_keys] W_local = torch.cat(W_local) - W[:, i] = W_local + W[:, idx] = W_local loss_regularizer = 0 loss_regularizer += W.norm() ** 2 - k = 10000 + k = 4000 for i in range(W.shape[0] // k): x = W[i * k:(i+1) * k, :] loss_regularizer += x.mm(omega).mm(x.T).trace() @@ -116,16 +108,6 @@ def train(self, net, idx=-1, lr=0.1, omega=None, W_glob=None, i=None, w_glob_key batch_loss.append(loss.item()) - if not self.pretrain and self.args.verbose and batch_idx % 300 == 0: - if idx < 0: - print('Update Epoch: {} [{}/{} ({:.0f}%)], Epoch Loss: {:.4f}, Batch Loss: {:.4f}'.format( - iter, batch_idx * len(images), len(self.ldr_train.dataset), 100. * batch_idx / len(self.ldr_train), - sum(batch_loss)/len(batch_loss), loss.item())) - else: - print('Local model {}, Update Epoch: {} [{}/{} ({:.0f}%)], Epoch Loss: {:.4f}, Batch Loss: {:.4f}'.format( - idx, iter, batch_idx * len(images), len(self.ldr_train.dataset), 100. 
* batch_idx / len(self.ldr_train), - sum(batch_loss)/len(batch_loss), loss.item())) - epoch_loss.append(sum(batch_loss)/len(batch_loss)) return net.state_dict(), sum(epoch_loss) / len(epoch_loss) diff --git a/models/test.py b/models/test.py index 084d22a..2df3fcc 100644 --- a/models/test.py +++ b/models/test.py @@ -2,10 +2,14 @@ # -*- coding: utf-8 -*- # @python: 3.6 +import copy +import numpy as np +from scipy import stats import torch from torch import nn import torch.nn.functional as F from torch.utils.data import DataLoader, Dataset +import pdb class DatasetSplit(Dataset): def __init__(self, dataset, idxs): @@ -19,7 +23,6 @@ def __getitem__(self, item): image, label = self.dataset[self.idxs[item]] return image, label - def test_img(net_g, datatest, args, return_probs=False, user_idx=-1): net_g.eval() # testing @@ -84,3 +87,72 @@ def test_img_local(net_g, dataset, args, user_idx=-1, idxs=None): user_idx, test_loss, correct, len(data_loader.dataset), accuracy)) return accuracy, test_loss + +def test_img_local_all(net_local_list, args, dataset_test, dict_users_test, return_all=False): + acc_test_local = np.zeros(args.num_users) + loss_test_local = np.zeros(args.num_users) + for idx in range(args.num_users): + net_local = net_local_list[idx] + net_local.eval() + a, b = test_img_local(net_local, dataset_test, args, user_idx=idx, idxs=dict_users_test[idx]) + + acc_test_local[idx] = a + loss_test_local[idx] = b + + if return_all: + return acc_test_local, loss_test_local + return acc_test_local.mean(), loss_test_local.mean() + +def test_img_avg_all(net_glob, net_local_list, args, dataset_test, return_net=False): + net_glob_temp = copy.deepcopy(net_glob) + w_keys_epoch = net_glob.state_dict().keys() + w_glob_temp = {} + for idx in range(args.num_users): + net_local = net_local_list[idx] + w_local = net_local.state_dict() + + if len(w_glob_temp) == 0: + w_glob_temp = copy.deepcopy(w_local) + else: + for k in w_keys_epoch: + w_glob_temp[k] += w_local[k] + + for k in w_keys_epoch: + w_glob_temp[k] = torch.div(w_glob_temp[k], args.num_users) + net_glob_temp.load_state_dict(w_glob_temp) + acc_test_avg, loss_test_avg = test_img(net_glob_temp, dataset_test, args) + + if return_net: + return acc_test_avg, loss_test_avg, net_glob_temp + return acc_test_avg, loss_test_avg + +criterion = nn.CrossEntropyLoss() + +def test_img_ensemble_all(net_local_list, args, dataset_test): + probs_all = [] + preds_all = [] + for idx in range(args.num_users): + net_local = net_local_list[idx] + net_local.eval() + # _, _, probs = test_img(net_local, dataset_test, args, return_probs=True, user_idx=idx) + acc, loss, probs = test_img(net_local, dataset_test, args, return_probs=True, user_idx=idx) + # print('Local model: {}, loss: {}, acc: {}'.format(idx, loss, acc)) + probs_all.append(probs.detach()) + + preds = probs.data.max(1, keepdim=True)[1].cpu().numpy().reshape(-1) + preds_all.append(preds) + + labels = np.array(dataset_test.targets) + preds_probs = torch.mean(torch.stack(probs_all), dim=0) + + # ensemble (avg) metrics + preds_avg = preds_probs.data.max(1, keepdim=True)[1].cpu().numpy().reshape(-1) + loss_test = criterion(preds_probs, torch.tensor(labels).to(args.device)).item() + acc_test_avg = (preds_avg == labels).mean() * 100 + + # ensemble (maj) + preds_all = np.array(preds_all).T + preds_maj = stats.mode(preds_all, axis=1)[0].reshape(-1) + acc_test_maj = (preds_maj == labels).mean() * 100 + + return acc_test_avg, loss_test, acc_test_maj \ No newline at end of file diff --git a/scripts/cifar10_ablation.sh 
b/scripts/cifar10_ablation.sh new file mode 100755 index 0000000..14e4178 --- /dev/null +++ b/scripts/cifar10_ablation.sh @@ -0,0 +1,25 @@ +#!/bin/bash + +SHARDS=2 +while getopts s:n: option; do + case "${option}" in + s) SHARDS=${OPTARG};; + esac +done + +for RUN in 1 2 3 4 5; do + for NUM_USERS in 20 50 200; do + python3 main_fed.py --dataset cifar10 --model cnn --num_classes 10 --epochs 2000 --lr 0.1 \ + --num_users ${NUM_USERS} --shard_per_user ${SHARDS} --frac 0.1 --local_ep 1 --local_bs 50 --results_save run${RUN} + + python3 main_local.py --dataset cifar10 --model cnn --num_classes 10 --epochs 200 --lr 0.1 \ + --num_users ${NUM_USERS} --shard_per_user ${SHARDS} --frac 0.1 --local_ep 1 --local_bs 50 --results_save run${RUN} + + for FED_MODEL in 1000 1200 1400 1600 1800; do + python3 main_lg.py --dataset cifar10 --model cnn --num_classes 10 --epochs 200 --lr 0.1 \ + --num_users ${NUM_USERS} --shard_per_user ${SHARDS} --frac 0.1 --local_ep 1 --local_bs 50 --num_layers_keep 2 \ + --results_save run${RUN} --load_fed best_${FED_MODEL}.pt + done + done +done + diff --git a/scripts/main_fed_cifar10.sh b/scripts/main_fed_cifar10.sh new file mode 100755 index 0000000..f14665a --- /dev/null +++ b/scripts/main_fed_cifar10.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +NUM_USERS=100 +SHARDS=2 +while getopts s:n: option; do + case "${option}" in + s) SHARDS=${OPTARG};; + n) NUM_USERS=${OPTARG};; + esac +done + + +for RUN in 1 2 3 4 5; do + python3 main_fed.py --dataset cifar10 --model cnn --num_classes 10 --epochs 2000 --lr 0.1 \ + --num_users ${NUM_USERS} --shard_per_user ${SHARDS} --frac 0.1 --local_ep 1 --local_bs 50 --results_save run${RUN} + python3 main_local.py --dataset cifar10 --model cnn --num_classes 10 --epochs 200 --lr 0.1 \ + --num_users ${NUM_USERS} --shard_per_user ${SHARDS} --frac 0.1 --local_ep 1 --local_bs 50 --results_save run${RUN} +done + diff --git a/scripts/main_fed_cifar10_iid.sh b/scripts/main_fed_cifar10_iid.sh new file mode 100755 index 0000000..a1929d7 --- /dev/null +++ b/scripts/main_fed_cifar10_iid.sh @@ -0,0 +1,17 @@ +#!/bin/bash + +NUM_USERS=100 +while getopts s:n: option; do + case "${option}" in + n) NUM_USERS=${OPTARG};; + esac +done + + +for RUN in 1 2 3 4 5; do + python3 main_fed.py --dataset cifar10 --model cnn --num_classes 10 --epochs 2000 --lr 0.1 \ + --num_users ${NUM_USERS} --shard_per_user 10 --frac 0.1 --local_ep 1 --local_bs 50 --results_save run${RUN} --iid + python3 main_local.py --dataset cifar10 --model cnn --num_classes 10 --epochs 400 --lr 0.1 \ + --num_users ${NUM_USERS} --shard_per_user 10 --frac 0.1 --local_ep 1 --local_bs 50 --results_save run${RUN} --iid +done + diff --git a/scripts/main_fed_mnist.sh b/scripts/main_fed_mnist.sh new file mode 100755 index 0000000..dbe4002 --- /dev/null +++ b/scripts/main_fed_mnist.sh @@ -0,0 +1,18 @@ +#!/bin/bash + +NUM_USERS=100 +SHARDS=2 +while getopts s:n: option; do + case "${option}" in + s) SHARDS=${OPTARG};; + n) NUM_USERS=${OPTARG};; + esac +done + +for RUN in 1 2 3 4 5; do + python3 main_fed.py --dataset mnist --model mlp --num_classes 10 --epochs 1000 --lr 0.05 \ + --num_users ${NUM_USERS} --shard_per_user ${SHARDS} --frac 0.1 --local_ep 1 --local_bs 10 --results_save run${RUN} + python3 main_local.py --dataset mnist --model mlp --num_classes 10 --epochs 100 --lr 0.05 \ + --num_users ${NUM_USERS} --shard_per_user ${SHARDS} --frac 0.1 --local_ep 1 --local_bs 10 --results_save run${RUN} +done + diff --git a/scripts/main_fed_mnist_iid.sh b/scripts/main_fed_mnist_iid.sh new file mode 100755 index 
0000000..754486d --- /dev/null +++ b/scripts/main_fed_mnist_iid.sh @@ -0,0 +1,16 @@ +#!/bin/bash + +NUM_USERS=100 +while getopts s:n: option; do + case "${option}" in + n) NUM_USERS=${OPTARG};; + esac +done + +for RUN in 1 2 3 4 5; do + python3 main_fed.py --dataset mnist --model mlp --num_classes 10 --epochs 1000 --lr 0.05 \ + --num_users ${NUM_USERS} --shard_per_user 10 --frac 0.1 --local_ep 1 --local_bs 10 --results_save run${RUN} --iid + python3 main_local.py --dataset mnist --model mlp --num_classes 10 --epochs 400 --lr 0.05 \ + --num_users ${NUM_USERS} --shard_per_user 10 --frac 0.1 --local_ep 1 --local_bs 10 --results_save run${RUN} --iid +done + diff --git a/scripts/main_lg_cifar10.sh b/scripts/main_lg_cifar10.sh new file mode 100755 index 0000000..a09a0f3 --- /dev/null +++ b/scripts/main_lg_cifar10.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +NUM_USERS=100 +SHARDS=2 +while getopts s:n: option; do + case "${option}" in + s) SHARDS=${OPTARG};; + n) NUM_USERS=${OPTARG};; + esac +done + +for FED_MODEL in 800 1000 1200 1400 1600 1800; do + for RUN in 1 2 3 4 5; do + python3 main_lg.py --dataset cifar10 --model cnn --num_classes 10 --epochs 200 --lr 0.1 \ + --num_users ${NUM_USERS} --shard_per_user ${SHARDS} --frac 0.1 --local_ep 1 --local_bs 50 --num_layers_keep 2 \ + --results_save run${RUN} --load_fed best_${FED_MODEL}.pt + done +done + diff --git a/scripts/main_lg_cifar10_iid.sh b/scripts/main_lg_cifar10_iid.sh new file mode 100755 index 0000000..3abb075 --- /dev/null +++ b/scripts/main_lg_cifar10_iid.sh @@ -0,0 +1,17 @@ +#!/bin/bash + +NUM_USERS=100 +while getopts s:n: option; do + case "${option}" in + n) NUM_USERS=${OPTARG};; + esac +done + +for FED_MODEL in 800 1000 1200 1400 1600 1800; do + for RUN in 1 2 3 4 5; do + python3 main_lg.py --dataset cifar10 --model cnn --num_classes 10 --epochs 200 --lr 0.1 \ + --num_users ${NUM_USERS} --shard_per_user 10 --frac 0.1 --local_ep 1 --local_bs 50 --num_layers_keep 2 \ + --results_save run${RUN} --load_fed best_${FED_MODEL}.pt --iid + done +done + diff --git a/scripts/main_lg_mnist.sh b/scripts/main_lg_mnist.sh new file mode 100755 index 0000000..65344b8 --- /dev/null +++ b/scripts/main_lg_mnist.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +NUM_USERS=100 +SHARDS=2 +while getopts s:n: option; do + case "${option}" in + s) SHARDS=${OPTARG};; + n) NUM_USERS=${OPTARG};; + esac +done + +for FED_MODEL in 400 500 600 700 800; do + for RUN in 1 2 3 4 5; do + python3 main_lg.py --dataset mnist --model mlp --num_classes 10 --epochs 200 --lr 0.05 \ + --num_users ${NUM_USERS} --shard_per_user ${SHARDS} --frac 0.1 --local_ep 1 --local_bs 10 --num_layers_keep 3 \ + --results_save run${RUN} --load_fed best_${FED_MODEL}.pt + done +done + diff --git a/scripts/main_lg_mnist_iid.sh b/scripts/main_lg_mnist_iid.sh new file mode 100755 index 0000000..6c4da70 --- /dev/null +++ b/scripts/main_lg_mnist_iid.sh @@ -0,0 +1,17 @@ +#!/bin/bash + +NUM_USERS=100 +while getopts s:n: option; do + case "${option}" in + n) NUM_USERS=${OPTARG};; + esac +done + +for FED_MODEL in 400 500 600 700 800; do + for RUN in 1 2 3 4 5; do + python3 main_lg.py --dataset mnist --model mlp --num_classes 10 --epochs 200 --lr 0.05 \ + --num_users ${NUM_USERS} --shard_per_user 10 --frac 0.1 --local_ep 1 --local_bs 10 --num_layers_keep 3 \ + --results_save run${RUN} --load_fed best_${FED_MODEL}.pt --iid + done +done + diff --git a/scripts/main_mtl_cifar10.sh b/scripts/main_mtl_cifar10.sh new file mode 100755 index 0000000..89baa3a --- /dev/null +++ b/scripts/main_mtl_cifar10.sh @@ -0,0 +1,17 @@ +#!/bin/bash + 
+NUM_USERS=100 +SHARDS=2 +while getopts s:n: option; do + case "${option}" in + s) SHARDS=${OPTARG};; + n) NUM_USERS=${OPTARG};; + esac +done + +for RUN in 1 2 3 4 5; do + python3 main_mtl.py --dataset cifar10 --model cnn --num_classes 10 --epochs 1800 --lr 0.1 \ + --num_users ${NUM_USERS} --shard_per_user ${SHARDS} --frac 0.1 --local_ep 1 --local_bs 50 --num_layers_keep 5 \ + --results_save run${RUN} +done + diff --git a/scripts/main_mtl_mnist.sh b/scripts/main_mtl_mnist.sh new file mode 100755 index 0000000..9282ff4 --- /dev/null +++ b/scripts/main_mtl_mnist.sh @@ -0,0 +1,17 @@ +#!/bin/bash + +NUM_USERS=100 +SHARDS=2 +while getopts s:n: option; do + case "${option}" in + s) SHARDS=${OPTARG};; + n) NUM_USERS=${OPTARG};; + esac +done + +for RUN in 1 2 3 4 5; do + python3 main_mtl.py --dataset mnist --model mlp --num_classes 10 --epochs 800 --lr 0.05 \ + --num_users ${NUM_USERS} --shard_per_user ${SHARDS} --frac 0.1 --local_ep 1 --local_bs 10 --num_layers_keep 5 \ + --results_save run${RUN} +done + diff --git a/scripts/mnist_ablation.sh b/scripts/mnist_ablation.sh new file mode 100755 index 0000000..095bff6 --- /dev/null +++ b/scripts/mnist_ablation.sh @@ -0,0 +1,27 @@ +#!/bin/bash + +SHARDS=2 +while getopts s:n: option; do + case "${option}" in + s) SHARDS=${OPTARG};; + esac +done + +for RUN in 1 2 3 4 5; do + for NUM_USERS in 20 50 200; do + python3 main_fed.py --dataset mnist --model mlp --num_classes 10 --epochs 1000 --lr 0.05 \ + --num_users ${NUM_USERS} --shard_per_user ${SHARDS} --frac 0.1 --local_ep 1 --local_bs 10 --results_save run${RUN} + + python3 main_local.py --dataset mnist --model mlp --num_classes 10 --epochs 100 --lr 0.05 \ + --num_users ${NUM_USERS} --shard_per_user ${SHARDS} --frac 0.1 --local_ep 1 --local_bs 10 --results_save run${RUN} + + for FED_MODEL in 400 500 600 700 800; do + for RUN in 1 2 3 4 5; do + python3 main_lg.py --dataset mnist --model mlp --num_classes 10 --epochs 200 --lr 0.05 \ + --num_users ${NUM_USERS} --shard_per_user ${SHARDS} --frac 0.1 --local_ep 1 --local_bs 10 --num_layers_keep 3 \ + --results_save run${RUN} --load_fed best_${FED_MODEL}.pt + done + done + done +done + diff --git a/test_local_numusers.py b/test_local_numusers.py new file mode 100644 index 0000000..5647914 --- /dev/null +++ b/test_local_numusers.py @@ -0,0 +1,109 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Python version: 3.6 + +import copy +import os +import pickle +import pandas as pd +import numpy as np +import torch +from torch import nn +from torch.utils.data import DataLoader + +from utils.options import args_parser +from utils.train_utils import get_data, get_model +from models.Update import DatasetSplit +from models.test import test_img_local, test_img_local_all, test_img_avg_all, test_img_ensemble_all + +from torchvision import datasets, transforms +from models.Nets import MLP, CNNMnist, CNNCifar +from utils.sampling import iid, noniid, noniid_replace + +import pdb + +results = {} + +# manually set arguments +args = args_parser() +args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu') +args.num_classes = 10 +args.epochs = 100 +if args.dataset == "mnist": + args.local_bs = 10 +else: + args.local_bs = 50 + +early_stopping = 20 + +args.shard_per_user = 2 +for num_users in [5, 10, 20, 50, 100, 200, 500, 1000]: + results[num_users] = [] + for run in range(10): + args.num_users = num_users + + # dataset + dataset_train, dataset_test, dict_users_train, dict_users_test = get_data(args) + + # model + net_glob = 
get_model(args) + net_glob.train() + + net_local_list = [] + for user_ix in range(args.num_users): + net_local_list.append(copy.deepcopy(net_glob)) + + criterion = nn.CrossEntropyLoss() + lr = args.lr + + # run each local model + for user, net_local in enumerate(net_local_list): + net_best = None + best_acc = None + + ldr_train = DataLoader(DatasetSplit(dataset_train, dict_users_train[user]), batch_size=args.local_bs, shuffle=True) + optimizer = torch.optim.SGD(net_local.parameters(), lr=lr, momentum=0.5) + + epochs_since_improvement = 0 + for iter in range(args.epochs): + for batch_idx, (images, labels) in enumerate(ldr_train): + images, labels = images.to(args.device), labels.to(args.device) + net_local.zero_grad() + log_probs = net_local(images) + + loss = criterion(log_probs, labels) + loss.backward() + optimizer.step() + + acc_test, loss_test = test_img_local(net_local, dataset_test, args, user_idx=user, idxs=dict_users_test[user]) + + epochs_since_improvement += 1 + if best_acc is None or acc_test > best_acc: + best_acc = acc_test + net_best = copy.deepcopy(net_local) + epochs_since_improvement = 0 + + print('User {}, Epoch {}, Acc {:.2f}'.format(user, iter, acc_test)) + if epochs_since_improvement >= early_stopping: + print("Early stopping...") + break + + net_local_list[user] = net_best + + acc_test_local, loss_test_local = test_img_local_all(net_local_list, args, dataset_test, dict_users_test) + results[num_users].append(acc_test_local) + + results_save_path = "save/{}/test_local.pkl".format(args.dataset) + with open(results_save_path, 'wb') as file: + pickle.dump(results, file) + +import pickle +import numpy as np + +# with open("save/mnist/test_local.pkl",'rb') as file: x = pickle.load(file) +with open("save/cifar10/test_local.pkl",'rb') as file: x = pickle.load(file) + +for key, value in x.items(): print(key, np.array(value).mean()) + +for key, value in x.items(): print(key, np.array(value).std()) + diff --git a/utils/options.py b/utils/options.py index 292cf7e..d61052b 100644 --- a/utils/options.py +++ b/utils/options.py @@ -9,6 +9,7 @@ def args_parser(): # federated arguments parser.add_argument('--epochs', type=int, default=10, help="rounds of training") parser.add_argument('--num_users', type=int, default=100, help="number of users: K") + parser.add_argument('--shard_per_user', type=int, default=2, help="classes per user") parser.add_argument('--frac', type=float, default=0.1, help="the fraction of clients: C") parser.add_argument('--local_ep', type=int, default=5, help="the number of local epochs: E") parser.add_argument('--local_bs', type=int, default=10, help="local batch size: B") diff --git a/utils/sampling.py b/utils/sampling.py index 920689b..65cae04 100755 --- a/utils/sampling.py +++ b/utils/sampling.py @@ -2,9 +2,12 @@ # -*- coding: utf-8 -*- # Python version: 3.6 - +import math +import random +from itertools import permutations import numpy as np -from torchvision import datasets, transforms +import torch +import pdb def fair_iid(dataset, num_users): """ @@ -64,7 +67,7 @@ def fair_noniid(train_data, num_users, num_shards=200, num_imgs=300, train=True, return dict_users, rand_set_all -def mnist_iid(dataset, num_users): +def iid(dataset, num_users): """ Sample I.I.D. 
client data from MNIST dataset :param dataset: @@ -78,140 +81,95 @@ def mnist_iid(dataset, num_users): all_idxs = list(set(all_idxs) - dict_users[i]) return dict_users - -def mnist_noniid(dataset, num_users, num_shards=200, num_imgs=300, train=True, rand_set_all=[]): +def noniid(dataset, num_users, shard_per_user, rand_set_all=[]): """ Sample non-I.I.D client data from MNIST dataset :param dataset: :param num_users: :return: """ - - assert num_shards % num_users == 0 - shard_per_user = int(num_shards / num_users) - - idx_shard = [i for i in range(num_shards)] dict_users = {i: np.array([], dtype='int64') for i in range(num_users)} - idxs = np.arange(num_shards*num_imgs) - labels = np.array(dataset.targets) - - assert num_shards * num_imgs == len(labels) - # sort labels - idxs_labels = np.vstack((idxs, labels)) - idxs_labels = idxs_labels[:,idxs_labels[1,:].argsort()] - idxs = idxs_labels[0,:] + idxs_dict = {} + for i in range(len(dataset)): + label = torch.tensor(dataset.targets[i]).item() + if label not in idxs_dict.keys(): + idxs_dict[label] = [] + idxs_dict[label].append(i) + + num_classes = len(np.unique(dataset.targets)) + shard_per_class = int(shard_per_user * num_users / num_classes) + for label in idxs_dict.keys(): + x = idxs_dict[label] + num_leftover = len(x) % shard_per_class + leftover = x[-num_leftover:] if num_leftover > 0 else [] + x = np.array(x[:-num_leftover]) if num_leftover > 0 else np.array(x) + x = x.reshape((shard_per_class, -1)) + x = list(x) + + for i, idx in enumerate(leftover): + x[i] = np.concatenate([x[i], [idx]]) + idxs_dict[label] = x - # divide and assign if len(rand_set_all) == 0: - for i in range(num_users): - rand_set = set(np.random.choice(idx_shard, shard_per_user, replace=False)) - for rand in rand_set: - rand_set_all.append(rand) - - idx_shard = list(set(idx_shard) - rand_set) # remove shards from possible choices for other users - for rand in rand_set: - dict_users[i] = np.concatenate((dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0) - - else: # this only works if the train and test set have the same distribution of labels - for i in range(num_users): - rand_set = rand_set_all[i*shard_per_user: (i+1)*shard_per_user] - for rand in rand_set: - dict_users[i] = np.concatenate((dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0) - - return dict_users, rand_set_all - - -def cifar10_iid(dataset, num_users): - """ - Sample I.I.D. 
client data from CIFAR10 dataset - :param dataset: - :param num_users: - :return: dict of image index - """ - num_items = int(len(dataset)/num_users) - dict_users, all_idxs = {}, [i for i in range(len(dataset))] - for i in range(num_users): - dict_users[i] = set(np.random.choice(all_idxs, num_items, replace=False)) - all_idxs = list(set(all_idxs) - dict_users[i]) - return dict_users - -''' -def cifar10_noniid(dataset, num_users): - """ - Sample non-I.I.D client data from MNIST dataset - :param dataset: - :param num_users: - :return: - """ - num_shards, num_imgs = 200, 250 - idx_shard = [i for i in range(num_shards)] - dict_users = {i: np.array([], dtype='int64') for i in range(num_users)} - idxs = np.arange(num_shards*num_imgs) - labels = np.array(dataset.train_labels) - - # sort labels - idxs_labels = np.vstack((idxs, labels)) - idxs_labels = idxs_labels[:,idxs_labels[1,:].argsort()] - idxs = idxs_labels[0,:] + rand_set_all = list(range(num_classes)) * shard_per_class + random.shuffle(rand_set_all) + rand_set_all = np.array(rand_set_all).reshape((num_users, -1)) # divide and assign for i in range(num_users): - rand_set = set(np.random.choice(idx_shard, 2, replace=False)) - idx_shard = list(set(idx_shard) - rand_set) - for rand in rand_set: - dict_users[i] = np.concatenate((dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0) - return dict_users -''' + rand_set_label = rand_set_all[i] + rand_set = [] + for label in rand_set_label: + idx = np.random.choice(len(idxs_dict[label]), replace=False) + rand_set.append(idxs_dict[label].pop(idx)) + dict_users[i] = np.concatenate(rand_set) + + test = [] + for key, value in dict_users.items(): + x = np.unique(torch.tensor(dataset.targets)[value]) + assert(len(x)) <= shard_per_user + test.append(value) + test = np.concatenate(test) + assert(len(test) == len(dataset)) + assert(len(set(list(test))) == len(dataset)) + + return dict_users, rand_set_all -def cifar10_noniid(dataset, num_users, num_shards=200, num_imgs=250, train=True, rand_set_all=[]): +def noniid_replace(dataset, num_users, shard_per_user, rand_set_all=[]): """ Sample non-I.I.D client data from MNIST dataset :param dataset: :param num_users: :return: """ - - assert num_shards % num_users == 0 - shard_per_user = int(num_shards / num_users) - - idx_shard = [i for i in range(num_shards)] + imgs_per_shard = int(len(dataset) / (num_users * shard_per_user)) dict_users = {i: np.array([], dtype='int64') for i in range(num_users)} - idxs = np.arange(num_shards*num_imgs) - labels = np.array(dataset.targets) - - assert num_shards * num_imgs == len(labels) - # sort labels - idxs_labels = np.vstack((idxs, labels)) - idxs_labels = idxs_labels[:,idxs_labels[1,:].argsort()] - idxs = idxs_labels[0,:] + idxs_dict = {} + for i in range(len(dataset)): + label = torch.tensor(dataset.targets[i]).item() + if label not in idxs_dict.keys(): + idxs_dict[label] = [] + idxs_dict[label].append(i) - # divide and assign + num_classes = len(np.unique(dataset.targets)) if len(rand_set_all) == 0: for i in range(num_users): - rand_set = set(np.random.choice(idx_shard, shard_per_user, replace=False)) - for rand in rand_set: - rand_set_all.append(rand) - - idx_shard = list(set(idx_shard) - rand_set) # remove shards from possible choices for other users - for rand in rand_set: - dict_users[i] = np.concatenate((dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0) + x = np.random.choice(np.arange(num_classes), shard_per_user, replace=False) + rand_set_all.append(x) - else: # this only works if the train and 
test set have the same distribution of labels - for i in range(num_users): - rand_set = rand_set_all[i*shard_per_user: (i+1)*shard_per_user] - for rand in rand_set: - dict_users[i] = np.concatenate((dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0) - - return dict_users, rand_set_all - - -if __name__ == '__main__': - dataset_train = datasets.MNIST('../data/mnist/', train=True, download=True, - transform=transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.1307,), (0.3081,)) - ])) - num = 100 - d = mnist_noniid(dataset_train, num) + # divide and assign + for i in range(num_users): + rand_set_label = rand_set_all[i] + rand_set = [] + for label in rand_set_label: + pdb.set_trace() + x = np.random.choice(idxs_dict[label], imgs_per_shard, replace=False) + rand_set.append(x) + dict_users[i] = np.concatenate(rand_set) + + for key, value in dict_users.items(): + assert(len(np.unique(torch.tensor(dataset.targets)[value]))) == shard_per_user + + return dict_users, rand_set_all \ No newline at end of file diff --git a/utils/train_utils.py b/utils/train_utils.py index 2c227d3..8015190 100644 --- a/utils/train_utils.py +++ b/utils/train_utils.py @@ -1,46 +1,59 @@ from torchvision import datasets, transforms from models.Nets import MLP, CNNMnist, CNNCifar -from utils.sampling import mnist_iid, mnist_noniid, cifar10_iid, cifar10_noniid +from utils.sampling import iid, noniid -def get_data(args, rand_set_all=[]): - trans_mnist = transforms.Compose([transforms.ToTensor(), - transforms.Normalize((0.1307,), (0.3081,))]) - trans_cifar_train = transforms.Compose([transforms.RandomCrop(32, padding=4), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), - transforms.Normalize(mean=[0.485, 0.456, 0.406], - std=[0.229, 0.224, 0.225])]) - trans_cifar_val = transforms.Compose([transforms.ToTensor(), +trans_mnist = transforms.Compose([transforms.ToTensor(), + transforms.Normalize((0.1307,), (0.3081,))]) +trans_cifar10_train = transforms.Compose([transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) +trans_cifar10_val = transforms.Compose([transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225])]) +trans_cifar100_train = transforms.Compose([transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize(mean=[0.507, 0.487, 0.441], + std=[0.267, 0.256, 0.276])]) +trans_cifar100_val = transforms.Compose([transforms.ToTensor(), + transforms.Normalize(mean=[0.507, 0.487, 0.441], + std=[0.267, 0.256, 0.276])]) +def get_data(args): if args.dataset == 'mnist': dataset_train = datasets.MNIST('data/mnist/', train=True, download=True, transform=trans_mnist) dataset_test = datasets.MNIST('data/mnist/', train=False, download=True, transform=trans_mnist) # sample users if args.iid: - dict_users_train = mnist_iid(dataset_train, args.num_users) - dict_users_test = mnist_iid(dataset_test, args.num_users) + dict_users_train = iid(dataset_train, args.num_users) + dict_users_test = iid(dataset_test, args.num_users) else: - dict_users_train, rand_set_all = mnist_noniid(dataset_train, args.num_users, num_shards=200, num_imgs=300, - train=True, rand_set_all=rand_set_all) - dict_users_test, rand_set_all = mnist_noniid(dataset_test, args.num_users, num_shards=200, num_imgs=50, - train=False, rand_set_all=rand_set_all) + dict_users_train, rand_set_all = noniid(dataset_train, 
args.num_users, args.shard_per_user) + dict_users_test, rand_set_all = noniid(dataset_test, args.num_users, args.shard_per_user, rand_set_all=rand_set_all) elif args.dataset == 'cifar10': - dataset_train = datasets.CIFAR10('data/cifar10', train=True, download=True, transform=trans_cifar_train) - dataset_test = datasets.CIFAR10('data/cifar10', train=False, download=True, transform=trans_cifar_val) + dataset_train = datasets.CIFAR10('data/cifar10', train=True, download=True, transform=trans_cifar10_train) + dataset_test = datasets.CIFAR10('data/cifar10', train=False, download=True, transform=trans_cifar10_val) if args.iid: - dict_users_train = cifar10_iid(dataset_train, args.num_users) - dict_users_test = cifar10_iid(dataset_test, args.num_users) + dict_users_train = iid(dataset_train, args.num_users) + dict_users_test = iid(dataset_test, args.num_users) else: - dict_users_train, rand_set_all = cifar10_noniid(dataset_train, args.num_users, num_shards=200, num_imgs=250, - train=True, rand_set_all=rand_set_all) - dict_users_test, rand_set_all = cifar10_noniid(dataset_train, args.num_users, num_shards=200, num_imgs=250, - train=True, rand_set_all=rand_set_all) + dict_users_train, rand_set_all = noniid(dataset_train, args.num_users, args.shard_per_user) + dict_users_test, rand_set_all = noniid(dataset_test, args.num_users, args.shard_per_user, rand_set_all=rand_set_all) + elif args.dataset == 'cifar100': + dataset_train = datasets.CIFAR100('data/cifar100', train=True, download=True, transform=trans_cifar100_train) + dataset_test = datasets.CIFAR100('data/cifar100', train=False, download=True, transform=trans_cifar100_val) + if args.iid: + dict_users_train = iid(dataset_train, args.num_users) + dict_users_test = iid(dataset_test, args.num_users) + else: + dict_users_train, rand_set_all = noniid(dataset_train, args.num_users, args.shard_per_user) + dict_users_test, rand_set_all = noniid(dataset_test, args.num_users, args.shard_per_user, rand_set_all=rand_set_all) else: exit('Error: unrecognized dataset') - return dataset_train, dataset_test, dict_users_train, dict_users_test, rand_set_all + return dataset_train, dataset_test, dict_users_train, dict_users_test def get_model(args): if args.model == 'cnn' and args.dataset in ['cifar10', 'cifar100']:
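# NOTE: illustrative sketch only -- not part of the diff above. It shows one plausible way the
# helpers introduced in this change fit together (get_data/get_model from utils/train_utils.py,
# the evaluation helpers from models/test.py, and the --shard_per_user option added in
# utils/options.py). The actual local/federated training loop lives elsewhere (e.g. main_fed.py /
# main_lg.py) and is elided here; names not shown in the diff are assumptions.
import copy
import torch
from utils.options import args_parser
from utils.train_utils import get_data, get_model
from models.test import test_img_local_all, test_img_avg_all, test_img_ensemble_all

args = args_parser()
args.device = torch.device('cuda:{}'.format(args.gpu)
                           if torch.cuda.is_available() and args.gpu != -1 else 'cpu')
args.num_classes = 10

# get_data no longer threads rand_set_all through its interface; it returns four values,
# with the test split sampled from the same per-user label shards as the train split.
dataset_train, dataset_test, dict_users_train, dict_users_test = get_data(args)

# one personalized model per user, initialized from a shared global model
net_glob = get_model(args)
net_local_list = [copy.deepcopy(net_glob) for _ in range(args.num_users)]

# ... local / federated updates of net_local_list and net_glob would happen here ...

# evaluation: per-user test shards, a parameter-averaged model, and output ensembles
acc_local, loss_local = test_img_local_all(net_local_list, args, dataset_test, dict_users_test)
acc_avg, loss_avg = test_img_avg_all(net_glob, net_local_list, args, dataset_test)
acc_ens_avg, loss_ens, acc_ens_maj = test_img_ensemble_all(net_local_list, args, dataset_test)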