diff --git a/.gitignore b/.gitignore
index 0d20b64..3ac6e56 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1 +1,140 @@
-*.pyc
+.idea/
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
diff --git a/get_results.ipynb b/get_results.ipynb
new file mode 100644
index 0000000..d511a88
--- /dev/null
+++ b/get_results.ipynb
@@ -0,0 +1,385 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "import re\n",
+ "import numpy as np\n",
+ "import pandas as pd"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "iid = True\n",
+ "num_users = 100\n",
+ "frac = 0.1\n",
+ "local_ep = 1\n",
+ "\n",
+ "dataset = 'mnist'\n",
+ "# dataset = 'cifar10'\n",
+ "# dataset = 'cifar100'\n",
+ "\n",
+ "shard_per_user = 10\n",
+ "\n",
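+ "# rd_fed / rd_lg: result rounds used below when reading the FedAvg and LG-FedAvg CSVs\n",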
+ "if dataset == 'mnist':\n",
+ " model = 'mlp'\n",
+ " rd_lg = 100\n",
+ " rd_fed = 800 + int(rd_lg*0.15)\n",
+ "elif dataset == 'cifar10':\n",
+ " model = 'cnn'\n",
+ " rd_lg = 100\n",
+ " rd_fed = 1800 + int(rd_lg*0.04)\n",
+ "elif dataset == 'cifar100':\n",
+ " model = 'cnn'\n",
+ " rd_fed = 1800\n",
+ " rd_lg = 100"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "base_dir = './save/{}/{}_iid{}_num{}_C{}_le{}/shard{}/'.format(\n",
+ " dataset, model, iid, num_users, frac, local_ep, shard_per_user)\n",
+ "runs = os.listdir(base_dir)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "acc_fed = np.zeros(len(runs))\n",
+ "acc_local_localtest = np.zeros(len(runs))\n",
+ "acc_local_newtest_avg = np.zeros(len(runs))\n",
+ "acc_local_newtest_ens = np.zeros(len(runs))\n",
+ "lg_metrics = {}\n",
+ "\n",
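+ "# collect per-run metrics: FedAvg best accuracy, LocalOnly test accuracies, and LG-FedAvg results per loaded checkpoint\n",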
+ "for idx, run in enumerate(runs):\n",
+ " # FedAvg\n",
+ " base_dir_fed = os.path.join(base_dir, \"{}/fed\".format(run))\n",
+ " results_path_fed = os.path.join(base_dir_fed, \"results.csv\")\n",
+ " df_fed = pd.read_csv(results_path_fed)\n",
+ "\n",
+ " acc_fed[idx] = df_fed.loc[rd_fed - 1]['best_acc']\n",
+ " \n",
+ " # LocalOnly\n",
+ " base_dir_local = os.path.join(base_dir, \"{}/local\".format(run))\n",
+ " results_path_local = os.path.join(base_dir_local, \"results.csv\")\n",
+ " df_local = pd.read_csv(results_path_local)\n",
+ "\n",
+ " acc_local_localtest[idx] = df_local.loc[0]['acc_test_local']\n",
+ " acc_local_newtest_avg[idx] = df_local.loc[0]['acc_test_avg']\n",
+ " if 'acc_test_ens' in df_local.columns:\n",
+ " acc_local_newtest_ens[idx] = df_local.loc[0]['acc_test_ens']\n",
+ " else:\n",
+ " acc_local_newtest_ens[idx] = df_local.loc[0]['acc_test_ens_avg']\n",
+ " \n",
+ " # LGFed\n",
+ " base_dir_lg = os.path.join(base_dir, \"{}/lg/\".format(run))\n",
+ " lg_runs = os.listdir(base_dir_lg)\n",
+ " for lg_run in lg_runs:\n",
+ " results_path_lg = os.path.join(base_dir_lg, \"{}/results.csv\".format(lg_run))\n",
+ " df_lg = pd.read_csv(results_path_lg)\n",
+ " \n",
+ " load_fed = int(re.split('best_|.pt', lg_run)[1])\n",
+ " if load_fed not in lg_metrics.keys():\n",
+ " lg_metrics[load_fed] = {'acc_local': np.zeros(len(runs)),\n",
+ " 'acc_avg': np.zeros(len(runs)),\n",
+ " 'acc_ens': np.zeros(len(runs))}\n",
+ " \n",
+ " x = df_lg.loc[rd_lg]['best_acc_local']\n",
+ " lg_metrics[load_fed]['acc_local'][idx] = x\n",
+ " idx_acc_local = df_lg[df_lg['best_acc_local'] == x].index[0]\n",
+ " lg_metrics[load_fed]['acc_avg'][idx] = df_lg.loc[idx_acc_local]['acc_test_avg']\n",
+ " if 'acc_test_ens' in df_lg.columns:\n",
+ " lg_metrics[load_fed]['acc_ens'][idx] = df_lg['acc_test_ens'].values[-1]\n",
+ " else:\n",
+ " lg_metrics[load_fed]['acc_ens'][idx] = df_lg['acc_test_ens_avg'].values[-1]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "columns = [\"Run\", \"Local Test\", \"New Test (avg)\", \"New Test (ens)\", \"FedAvg Rounds\", \"LG Rounds\"]\n",
+ "results = []"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "localonly:\t 88.03 +- 0.37\n",
+ "localonly_avg:\t 86.24 +- 0.87\n",
+ "localonly_ens:\t 91.15 +- 0.27\n"
+ ]
+ }
+ ],
+ "source": [
+ "str_acc_local_localtest = \"{:.2f} +- {:.2f}\".format(acc_local_localtest.mean(), acc_local_localtest.std())\n",
+ "str_acc_local_newtest_avg = \"{:.2f} +- {:.2f}\".format(acc_local_newtest_avg.mean(), acc_local_newtest_avg.std())\n",
+ "str_acc_local_newtest_ens = \"{:.2f} +- {:.2f}\".format(acc_local_newtest_ens.mean(), acc_local_newtest_ens.std())\n",
+ "\n",
+ "print(\"localonly:\\t\", str_acc_local_localtest)\n",
+ "print(\"localonly_avg:\\t\", str_acc_local_newtest_avg)\n",
+ "print(\"localonly_ens:\\t\", str_acc_local_newtest_ens)\n",
+ "\n",
+ "results.append([\"LocalOnly\", str_acc_local_localtest, str_acc_local_newtest_avg, str_acc_local_newtest_ens, 0, 0])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "300\n",
+ "acc_local:\t97.45 +- 0.12\n",
+ "acc_avg:\t97.50 +- 0.10\n",
+ "acc_ens:\t97.49 +- 0.09\n",
+ "400\n",
+ "acc_local:\t97.59 +- 0.08\n",
+ "acc_avg:\t97.61 +- 0.08\n",
+ "acc_ens:\t97.62 +- 0.08\n",
+ "500\n",
+ "acc_local:\t97.78 +- 0.13\n",
+ "acc_avg:\t97.82 +- 0.14\n",
+ "acc_ens:\t97.82 +- 0.13\n",
+ "600\n",
+ "acc_local:\t97.84 +- 0.10\n",
+ "acc_avg:\t97.86 +- 0.08\n",
+ "acc_ens:\t97.87 +- 0.10\n",
+ "700\n",
+ "acc_local:\t97.85 +- 0.09\n",
+ "acc_avg:\t97.88 +- 0.09\n",
+ "acc_ens:\t97.89 +- 0.09\n",
+ "800\n",
+ "acc_local:\t97.91 +- 0.10\n",
+ "acc_avg:\t97.93 +- 0.07\n",
+ "acc_ens:\t97.93 +- 0.07\n"
+ ]
+ }
+ ],
+ "source": [
+ "for lg_run in sorted(lg_metrics.keys()):\n",
+ " x = [\"LG-FedAvg\"]\n",
+ " print(lg_run)\n",
+ " for array in ['acc_local', 'acc_avg', 'acc_ens']:\n",
+ " mean = lg_metrics[lg_run][array].mean()\n",
+ " std = lg_metrics[lg_run][array].std()\n",
+ " str_acc = \"{:.2f} +- {:.2f}\".format(mean, std)\n",
+ " print(\"{}:\\t{}\".format(array, str_acc))\n",
+ " \n",
+ " x.append(str_acc)\n",
+ " x.append(lg_run)\n",
+ " x.append(rd_lg)\n",
+ " results.append(x)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "fed:\t 97.93 +- 0.08\n"
+ ]
+ }
+ ],
+ "source": [
+ "str_acc_fed = \"{:.2f} +- {:.2f}\".format(acc_fed.mean(), acc_fed.std())\n",
+ "print(\"fed:\\t\", str_acc_fed)\n",
+ "results.append([\"FedAvg\", str_acc_fed, str_acc_fed, str_acc_fed, rd_fed, 0])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ " Local Test New Test (avg) New Test (ens) FedAvg Rounds \\\n",
+ "Run \n",
+ "LocalOnly 88.03 +- 0.37 86.24 +- 0.87 91.15 +- 0.27 0 \n",
+ "LG-FedAvg 97.45 +- 0.12 97.50 +- 0.10 97.49 +- 0.09 300 \n",
+ "LG-FedAvg 97.59 +- 0.08 97.61 +- 0.08 97.62 +- 0.08 400 \n",
+ "LG-FedAvg 97.78 +- 0.13 97.82 +- 0.14 97.82 +- 0.13 500 \n",
+ "LG-FedAvg 97.84 +- 0.10 97.86 +- 0.08 97.87 +- 0.10 600 \n",
+ "LG-FedAvg 97.85 +- 0.09 97.88 +- 0.09 97.89 +- 0.09 700 \n",
+ "LG-FedAvg 97.91 +- 0.10 97.93 +- 0.07 97.93 +- 0.07 800 \n",
+ "FedAvg 97.93 +- 0.08 97.93 +- 0.08 97.93 +- 0.08 815 \n",
+ "\n",
+ " LG Rounds \n",
+ "Run \n",
+ "LocalOnly 0 \n",
+ "LG-FedAvg 100 \n",
+ "LG-FedAvg 100 \n",
+ "LG-FedAvg 100 \n",
+ "LG-FedAvg 100 \n",
+ "LG-FedAvg 100 \n",
+ "LG-FedAvg 100 \n",
+ "FedAvg 0 "
+ ]
+ },
+ "execution_count": 10,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "pd.DataFrame(results, columns=columns).set_index(\"Run\")"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.6"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/main_fed.py b/main_fed.py
index 428da6a..013a3b8 100644
--- a/main_fed.py
+++ b/main_fed.py
@@ -3,7 +3,9 @@
# Python version: 3.6
import copy
+import pickle
import numpy as np
+import pandas as pd
import torch
from utils.options import args_parser
@@ -19,21 +21,22 @@
args = args_parser()
args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu')
- base_dir = './save/{}/{}_iid{}_num{}_C{}_le{}/{}/'.format(
- args.dataset, args.model, args.iid, args.num_users, args.frac, args.local_ep, args.results_save)
+ base_dir = './save/{}/{}_iid{}_num{}_C{}_le{}/shard{}/{}/'.format(
+ args.dataset, args.model, args.iid, args.num_users, args.frac, args.local_ep, args.shard_per_user, args.results_save)
if not os.path.exists(os.path.join(base_dir, 'fed')):
os.makedirs(os.path.join(base_dir, 'fed'), exist_ok=True)
- dataset_train, dataset_test, dict_users_train, _, rand_set_all = get_data(args)
- rand_save_path = os.path.join(base_dir, 'randset.npy')
- np.save(rand_save_path, rand_set_all)
+ dataset_train, dataset_test, dict_users_train, dict_users_test = get_data(args)
+ dict_save_path = os.path.join(base_dir, 'dict_users.pkl')
+ with open(dict_save_path, 'wb') as handle:
+ pickle.dump((dict_users_train, dict_users_test), handle)
# build model
net_glob = get_model(args)
net_glob.train()
# training
- results_save_path = os.path.join(base_dir, 'fed/results.npy')
+ results_save_path = os.path.join(base_dir, 'fed/results.csv')
loss_train = []
net_best = None
@@ -83,15 +86,25 @@
print('Round {:3d}, Average loss {:.3f}, Test loss {:.3f}, Test accuracy: {:.2f}'.format(
iter, loss_avg, loss_test, acc_test))
- results.append(np.array([iter, loss_avg, loss_test, acc_test]))
- final_results = np.array(results)
- np.save(results_save_path, final_results)
- model_save_path = os.path.join(base_dir, 'fed/model_{}.pt'.format(iter))
if best_acc is None or acc_test > best_acc:
+ net_best = copy.deepcopy(net_glob)
best_acc = acc_test
best_epoch = iter
- if iter > args.start_saving:
- torch.save(net_glob.state_dict(), model_save_path)
+
+ # if (iter + 1) > args.start_saving:
+ # model_save_path = os.path.join(base_dir, 'fed/model_{}.pt'.format(iter + 1))
+ # torch.save(net_glob.state_dict(), model_save_path)
+
+ results.append(np.array([iter, loss_avg, loss_test, acc_test, best_acc]))
+ final_results = np.array(results)
+ final_results = pd.DataFrame(final_results, columns=['epoch', 'loss_avg', 'loss_test', 'acc_test', 'best_acc'])
+ final_results.to_csv(results_save_path, index=False)
+
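+        # every 50 rounds, checkpoint both the best-so-far and the current global model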
+ if (iter + 1) % 50 == 0:
+ best_save_path = os.path.join(base_dir, 'fed/best_{}.pt'.format(iter + 1))
+ model_save_path = os.path.join(base_dir, 'fed/model_{}.pt'.format(iter + 1))
+ torch.save(net_best.state_dict(), best_save_path)
+ torch.save(net_glob.state_dict(), model_save_path)
print('Best model, iter: {}, acc: {}'.format(best_epoch, best_acc))
\ No newline at end of file
diff --git a/main_lg.py b/main_lg.py
index be975b7..ae704fc 100644
--- a/main_lg.py
+++ b/main_lg.py
@@ -4,15 +4,16 @@
import copy
import os
+import pickle
import itertools
+import pandas as pd
import numpy as np
import torch
-from torch import nn
from utils.options import args_parser
from utils.train_utils import get_data, get_model
from models.Update import LocalUpdate
-from models.test import test_img, test_img_local
+from models.test import test_img_local_all, test_img_avg_all, test_img_ensemble_all
import pdb
@@ -21,26 +22,24 @@
args = args_parser()
args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu')
- base_dir = './save/{}/{}_iid{}_num{}_C{}_le{}/{}/'.format(
- args.dataset, args.model, args.iid, args.num_users, args.frac, args.local_ep, args.results_save)
- if not os.path.exists(os.path.join(base_dir, 'lg')):
- os.makedirs(os.path.join(base_dir, 'lg'), exist_ok=True)
+ base_dir = './save/{}/{}_iid{}_num{}_C{}_le{}/shard{}/{}/'.format(
+ args.dataset, args.model, args.iid, args.num_users, args.frac, args.local_ep, args.shard_per_user, args.results_save)
- rand_set_all = []
- if len(args.load_fed) > 0:
- rand_save_path = os.path.join(base_dir, 'randset.npy')
- rand_set_all= np.load(rand_save_path)
+ assert(len(args.load_fed) > 0)
+ base_save_dir = os.path.join(base_dir, 'lg2/{}'.format(args.load_fed))
+ if not os.path.exists(base_save_dir):
+ os.makedirs(base_save_dir, exist_ok=True)
- dataset_train, dataset_test, dict_users_train, dict_users_test, rand_set_all = get_data(args, rand_set_all=rand_set_all)
- rand_save_path = os.path.join(base_dir, 'randset.npy')
- np.save(rand_save_path, rand_set_all)
+ dataset_train, dataset_test, dict_users_train, dict_users_test = get_data(args)
+ dict_save_path = os.path.join(base_dir, 'dict_users.pkl')
+ with open(dict_save_path, 'rb') as handle:
+ dict_users_train, dict_users_test = pickle.load(handle)
# build model
net_glob = get_model(args)
net_glob.train()
- if len(args.load_fed) > 0:
- fed_model_path = os.path.join(base_dir, 'fed/{}'.format(args.load_fed))
+ fed_model_path = os.path.join(base_dir, 'fed/{}'.format(args.load_fed))
net_glob.load_state_dict(torch.load(fed_model_path))
total_num_layers = len(net_glob.weight_keys)
@@ -56,96 +55,36 @@
percentage_param = 100 * float(num_param_glob) / num_param_local
print('# Params: {} (local), {} (global); Percentage {:.2f} ({}/{})'.format(
num_param_local, num_param_glob, percentage_param, num_param_glob, num_param_local))
+
# generate list of local models for each user
net_local_list = []
- for user_ix in range(args.num_users):
+ for user in range(args.num_users):
net_local_list.append(copy.deepcopy(net_glob))
- criterion = nn.CrossEntropyLoss()
-
-
- def test_img_ensemble_all():
- probs_all = []
- preds_all = []
- for idx in range(args.num_users):
- net_local = net_local_list[idx]
- net_local.eval()
- # _, _, probs = test_img(net_local, dataset_test, args, return_probs=True, user_idx=idx)
- acc, loss, probs = test_img(net_local, dataset_test, args, return_probs=True, user_idx=idx)
- # print('Local model: {}, loss: {}, acc: {}'.format(idx, loss, acc))
- probs_all.append(probs.detach())
-
- preds = probs.data.max(1, keepdim=True)[1].cpu().numpy().reshape(-1)
- preds_all.append(preds)
-
- labels = np.array(dataset_test.targets)
- preds_probs = torch.mean(torch.stack(probs_all), dim=0)
-
- # ensemble metrics
- preds_avg = preds_probs.data.max(1, keepdim=True)[1].cpu().numpy().reshape(-1)
- loss_test = criterion(preds_probs, torch.tensor(labels).to(args.device)).item()
- acc_test = (preds_avg == labels).mean() * 100
-
- return loss_test, acc_test
-
- def test_img_local_all():
- acc_test_local = 0
- loss_test_local = 0
- for idx in range(args.num_users):
- net_local = net_local_list[idx]
- net_local.eval()
- a, b = test_img_local(net_local, dataset_test, args, user_idx=idx, idxs=dict_users_test[idx])
-
- acc_test_local += a
- loss_test_local += b
- acc_test_local /= args.num_users
- loss_test_local /= args.num_users
-
- return acc_test_local, loss_test_local
-
- def test_img_avg_all():
- net_glob_temp = copy.deepcopy(net_glob)
- w_keys_epoch = net_glob.state_dict().keys()
- w_glob_temp = {}
- for idx in range(args.num_users):
- net_local = net_local_list[idx]
- w_local = net_local.state_dict()
-
- if len(w_glob_temp) == 0:
- w_glob_temp = copy.deepcopy(w_local)
- else:
- for k in w_keys_epoch:
- w_glob_temp[k] += w_local[k]
-
- for k in w_keys_epoch:
- w_glob_temp[k] = torch.div(w_glob_temp[k], args.num_users)
- net_glob_temp.load_state_dict(w_glob_temp)
- acc_test_avg, loss_test_avg = test_img(net_glob_temp, dataset_test, args)
-
- return acc_test_avg, loss_test_avg
+ acc_test_local, loss_test_local = test_img_local_all(net_local_list, args, dataset_test, dict_users_test)
+ acc_test_avg, loss_test_avg = test_img_avg_all(net_glob, net_local_list, args, dataset_test)
# training
- results_save_path = os.path.join(base_dir, 'lg/results.npy')
+ results_save_path = os.path.join(base_save_dir, 'results.csv')
loss_train = []
net_best = None
- best_loss = None
- best_acc = None
- best_epoch = None
+ best_acc = np.ones(args.num_users) * acc_test_local
+ best_net_list = copy.deepcopy(net_local_list)
lr = args.lr
results = []
+ results.append(np.array([-1, acc_test_local, acc_test_avg, best_acc.mean(), None, None]))
+ print('Round {:3d}, Avg Loss {:.3f}, Loss (local): {:.3f}, Acc (local): {:.2f}, Loss (avg): {:.3}, Acc (avg): {:.2f}'.format(
+ -1, -1, loss_test_local, acc_test_local, loss_test_avg, acc_test_avg))
for iter in range(args.epochs):
w_glob = {}
loss_locals = []
m = max(int(args.frac * args.num_users), 1)
idxs_users = np.random.choice(range(args.num_users), m, replace=False)
- # w_keys_epoch = net_glob.state_dict().keys() if (iter + 1) % 25 == 0 else w_glob_keys
w_keys_epoch = w_glob_keys
- if args.verbose:
- print("Round {}: lr: {:.6f}, {}".format(iter, lr, idxs_users))
for idx in idxs_users:
local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users_train[idx])
net_local = net_local_list[idx]
@@ -163,9 +102,6 @@ def test_img_avg_all():
for k in w_keys_epoch:
w_glob[k] += w_local[k]
- if (iter+1) % int(args.num_users * args.frac):
- lr *= args.lr_decay
-
# get weighted average for global weights
for k in w_keys_epoch:
w_glob[k] = torch.div(w_glob[k], m)
@@ -186,29 +122,31 @@ def test_img_avg_all():
loss_train.append(loss_avg)
# eval
- acc_test_local, loss_test_local = test_img_local_all()
- acc_test_avg, loss_test_avg = test_img_avg_all()
+ acc_test_local, loss_test_local = test_img_local_all(net_local_list, args, dataset_test, dict_users_test, return_all=True)
+
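+        # keep the best model per user and checkpoint it whenever its local accuracy improves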
+ for user in range(args.num_users):
+ if acc_test_local[user] > best_acc[user]:
+ best_acc[user] = acc_test_local[user]
+ best_net_list[user] = copy.deepcopy(net_local_list[user])
+
+ model_save_path = os.path.join(base_save_dir, 'model_user{}.pt'.format(user))
+ torch.save(best_net_list[user].state_dict(), model_save_path)
+ acc_test_local, loss_test_local = test_img_local_all(best_net_list, args, dataset_test, dict_users_test)
+ acc_test_avg, loss_test_avg = test_img_avg_all(net_glob, best_net_list, args, dataset_test)
print('Round {:3d}, Avg Loss {:.3f}, Loss (local): {:.3f}, Acc (local): {:.2f}, Loss (avg): {:.3}, Acc (avg): {:.2f}'.format(
iter, loss_avg, loss_test_local, acc_test_local, loss_test_avg, acc_test_avg))
- results.append(np.array([iter, loss_avg, loss_test_local, acc_test_local, loss_test_avg, acc_test_avg]))
+ results.append(np.array([iter, acc_test_local, acc_test_avg, best_acc.mean(), None, None]))
final_results = np.array(results)
- np.save(results_save_path, final_results)
-
- if best_acc is None or acc_test_local > best_acc:
- best_acc = acc_test_local
- best_epoch = iter
-
- for user in range(args.num_users):
- model_save_path = os.path.join(base_dir, 'lg/model_user{}.pt'.format(user))
- torch.save(net_local_list[user].state_dict(), model_save_path)
-
- for user in range(args.num_users):
- model_save_path = os.path.join(base_dir, 'lg/model_user{}.pt'.format(user))
+ final_results = pd.DataFrame(final_results, columns=['epoch', 'acc_test_local', 'acc_test_avg', 'best_acc_local', 'acc_test_ens_avg', 'acc_test_ens_maj'])
+ final_results.to_csv(results_save_path, index=False)
- net_local = net_local_list[idx]
- net_local.load_state_dict(torch.load(model_save_path))
- loss_test, acc_test = test_img_ensemble_all()
+ acc_test_ens_avg, loss_test, acc_test_ens_maj = test_img_ensemble_all(best_net_list, args, dataset_test)
+ print('Best model, acc (local): {}, acc (ens,avg): {}, acc (ens,maj): {}'.format(best_acc.mean(), acc_test_ens_avg, acc_test_ens_maj))
- print('Best model, iter: {}, acc: {}, acc (ens): {}'.format(best_epoch, best_acc, acc_test))
\ No newline at end of file
+ results.append(np.array(['Final', None, None, best_acc.mean(), acc_test_ens_avg, acc_test_ens_maj]))
+ final_results = np.array(results)
+ final_results = pd.DataFrame(final_results,
+ columns=['epoch', 'acc_test_local', 'acc_test_avg', 'best_acc_local', 'acc_test_ens_avg', 'acc_test_ens_maj'])
+ final_results.to_csv(results_save_path, index=False)
\ No newline at end of file
diff --git a/main_lg_backup.py b/main_lg_backup.py
new file mode 100644
index 0000000..c1a0583
--- /dev/null
+++ b/main_lg_backup.py
@@ -0,0 +1,156 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Python version: 3.6
+
+import copy
+import os
+import pickle
+import itertools
+import pandas as pd
+import numpy as np
+import torch
+
+from utils.options import args_parser
+from utils.train_utils import get_data, get_model
+from models.Update import LocalUpdate
+from models.test import test_img_local_all, test_img_avg_all, test_img_ensemble_all
+
+import pdb
+
+if __name__ == '__main__':
+ # parse args
+ args = args_parser()
+ args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu')
+
+ base_dir = './save/{}/{}_iid{}_num{}_C{}_le{}/shard{}/{}/'.format(
+ args.dataset, args.model, args.iid, args.num_users, args.frac, args.local_ep, args.shard_per_user, args.results_save)
+
+ assert(len(args.load_fed) > 0)
+ base_save_dir = os.path.join(base_dir, 'lg/{}'.format(args.load_fed))
+ if not os.path.exists(base_save_dir):
+ os.makedirs(base_save_dir, exist_ok=True)
+
+ dataset_train, dataset_test, dict_users_train, dict_users_test = get_data(args)
+ dict_save_path = os.path.join(base_dir, 'dict_users.pkl')
+ with open(dict_save_path, 'rb') as handle:
+ dict_users_train, dict_users_test = pickle.load(handle)
+
+ # build model
+ net_glob = get_model(args)
+ net_glob.train()
+
+ fed_model_path = os.path.join(base_dir, 'fed/{}'.format(args.load_fed))
+ net_glob.load_state_dict(torch.load(fed_model_path))
+
+ total_num_layers = len(net_glob.weight_keys)
+ w_glob_keys = net_glob.weight_keys[total_num_layers - args.num_layers_keep:]
+ w_glob_keys = list(itertools.chain.from_iterable(w_glob_keys))
+
+ num_param_glob = 0
+ num_param_local = 0
+ for key in net_glob.state_dict().keys():
+ num_param_local += net_glob.state_dict()[key].numel()
+ if key in w_glob_keys:
+ num_param_glob += net_glob.state_dict()[key].numel()
+ percentage_param = 100 * float(num_param_glob) / num_param_local
+ print('# Params: {} (local), {} (global); Percentage {:.2f} ({}/{})'.format(
+ num_param_local, num_param_glob, percentage_param, num_param_glob, num_param_local))
+ # generate list of local models for each user
+ net_local_list = []
+ for user_ix in range(args.num_users):
+ net_local_list.append(copy.deepcopy(net_glob))
+
+ # training
+ results_save_path = os.path.join(base_save_dir, 'results.csv')
+
+ loss_train = []
+ net_best = None
+ best_loss = None
+ best_acc = None
+ best_epoch = None
+
+ lr = args.lr
+ results = []
+
+ for iter in range(args.epochs):
+ w_glob = {}
+ loss_locals = []
+ m = max(int(args.frac * args.num_users), 1)
+ idxs_users = np.random.choice(range(args.num_users), m, replace=False)
+ # w_keys_epoch = net_glob.state_dict().keys() if (iter + 1) % 25 == 0 else w_glob_keys
+ w_keys_epoch = w_glob_keys
+
+ if args.verbose:
+ print("Round {}: lr: {:.6f}, {}".format(iter, lr, idxs_users))
+ for idx in idxs_users:
+ local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users_train[idx])
+ net_local = net_local_list[idx]
+
+ w_local, loss = local.train(net=net_local.to(args.device), lr=lr)
+ loss_locals.append(copy.deepcopy(loss))
+
+ modules_glob = set([x.split('.')[0] for x in w_keys_epoch])
+ modules_all = net_local.__dict__['_modules']
+
+ # sum up weights
+ if len(w_glob) == 0:
+ w_glob = copy.deepcopy(net_glob.state_dict())
+ else:
+ for k in w_keys_epoch:
+ w_glob[k] += w_local[k]
+
+ if (iter+1) % int(args.num_users * args.frac):
+ lr *= args.lr_decay
+
+ # get weighted average for global weights
+ for k in w_keys_epoch:
+ w_glob[k] = torch.div(w_glob[k], m)
+
+        # copy weights to the global model (not strictly necessary)
+ net_glob.load_state_dict(w_glob)
+
+ # copy weights to each local model
+ for idx in range(args.num_users):
+ net_local = net_local_list[idx]
+ w_local = net_local.state_dict()
+ for k in w_keys_epoch:
+ w_local[k] = w_glob[k]
+
+ net_local.load_state_dict(w_local)
+
+ loss_avg = sum(loss_locals) / len(loss_locals)
+ loss_train.append(loss_avg)
+
+ # eval
+ acc_test_local, loss_test_local = test_img_local_all(net_local_list, args, dataset_test, dict_users_test)
+ acc_test_avg, loss_test_avg = test_img_avg_all(net_glob, net_local_list, args, dataset_test)
+ print('Round {:3d}, Avg Loss {:.3f}, Loss (local): {:.3f}, Acc (local): {:.2f}, Loss (avg): {:.3}, Acc (avg): {:.2f}'.format(
+ iter, loss_avg, loss_test_local, acc_test_local, loss_test_avg, acc_test_avg))
+
+ if best_acc is None or acc_test_local > best_acc:
+ best_acc = acc_test_local
+ best_epoch = iter
+
+ for user in range(args.num_users):
+ model_save_path = os.path.join(base_save_dir, 'model_user{}.pt'.format(user))
+ torch.save(net_local_list[user].state_dict(), model_save_path)
+
+ results.append(np.array([iter, acc_test_local, acc_test_avg, best_acc, None, None]))
+ final_results = np.array(results)
+ final_results = pd.DataFrame(final_results, columns=['epoch', 'acc_test_local', 'acc_test_avg', 'best_acc_local', 'acc_test_ens_avg', 'acc_test_ens_maj'])
+ final_results.to_csv(results_save_path, index=False)
+
+ for user in range(args.num_users):
+ model_save_path = os.path.join(base_save_dir, 'model_user{}.pt'.format(user))
+
+ net_local = net_local_list[user]
+ net_local.load_state_dict(torch.load(model_save_path))
+ acc_test_ens_avg, loss_test, acc_test_ens_maj = test_img_ensemble_all(net_local_list, args, dataset_test)
+
+ print('Best model, iter: {}, acc (local): {}, acc (ens,avg): {}, acc (ens,maj): {}'.format(best_epoch, best_acc, acc_test_ens_avg, acc_test_ens_maj))
+
+ results.append(np.array(['Final', None, None, best_acc, acc_test_ens_avg, acc_test_ens_maj]))
+ final_results = np.array(results)
+ final_results = pd.DataFrame(final_results,
+ columns=['epoch', 'acc_test_local', 'acc_test_avg', 'best_acc_local', 'acc_test_ens_avg', 'acc_test_ens_maj'])
+ final_results.to_csv(results_save_path, index=False)
diff --git a/main_local.py b/main_local.py
new file mode 100644
index 0000000..792414b
--- /dev/null
+++ b/main_local.py
@@ -0,0 +1,96 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Python version: 3.6
+
+import copy
+import os
+import pickle
+import pandas as pd
+import numpy as np
+import torch
+from torch import nn
+from torch.utils.data import DataLoader
+
+from utils.options import args_parser
+from utils.train_utils import get_data, get_model
+from models.Update import DatasetSplit
+from models.test import test_img_local, test_img_local_all, test_img_avg_all, test_img_ensemble_all
+
+import pdb
+
+if __name__ == '__main__':
+ # parse args
+ args = args_parser()
+ args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu')
+
+ base_dir = './save/{}/{}_iid{}_num{}_C{}_le{}/shard{}/{}/'.format(
+ args.dataset, args.model, args.iid, args.num_users, args.frac, args.local_ep, args.shard_per_user, args.results_save)
+ if not os.path.exists(os.path.join(base_dir, 'local')):
+ os.makedirs(os.path.join(base_dir, 'local'), exist_ok=True)
+
+ dataset_train, dataset_test, dict_users_train, dict_users_test = get_data(args)
+ dict_save_path = os.path.join(base_dir, 'dict_users.pkl')
+ with open(dict_save_path, 'rb') as handle:
+ dict_users_train, dict_users_test = pickle.load(handle)
+
+ # build model
+ net_glob = get_model(args)
+ net_glob.train()
+
+ net_local_list = []
+ for user_ix in range(args.num_users):
+ net_local_list.append(copy.deepcopy(net_glob))
+
+ # training
+ results_save_path = os.path.join(base_dir, 'local/results.csv')
+
+ loss_train = []
+ net_best = None
+ best_loss = None
+ best_acc = None
+ best_epoch = None
+
+ lr = args.lr
+ results = []
+
+ criterion = nn.CrossEntropyLoss()
+
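+    # train each user's model independently on its own split; keep the best-performing checkpoint per user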
+ for user, net_local in enumerate(net_local_list):
+ model_save_path = os.path.join(base_dir, 'local/model_user{}.pt'.format(user))
+ net_best = None
+ best_acc = None
+
+ ldr_train = DataLoader(DatasetSplit(dataset_train, dict_users_train[user]), batch_size=args.local_bs, shuffle=True)
+ optimizer = torch.optim.SGD(net_local.parameters(), lr=lr, momentum=0.5)
+ for iter in range(args.epochs):
+ for batch_idx, (images, labels) in enumerate(ldr_train):
+ images, labels = images.to(args.device), labels.to(args.device)
+ net_local.zero_grad()
+ log_probs = net_local(images)
+
+ loss = criterion(log_probs, labels)
+ loss.backward()
+ optimizer.step()
+
+ acc_test, loss_test = test_img_local(net_local, dataset_test, args, user_idx=user, idxs=dict_users_test[user])
+ if best_acc is None or acc_test > best_acc:
+ best_acc = acc_test
+ net_best = copy.deepcopy(net_local)
+ # torch.save(net_local_list[user].state_dict(), model_save_path)
+
+ print('User {}, Epoch {}, Acc {:.2f}'.format(user, iter, acc_test))
+
+ if iter > 50 and acc_test >= 99:
+ break
+
+ net_local_list[user] = net_best
+
+ acc_test_local, loss_test_local = test_img_local_all(net_local_list, args, dataset_test, dict_users_test)
+ acc_test_avg, loss_test_avg = test_img_avg_all(net_glob, net_local_list, args, dataset_test)
+ acc_test_ens_avg, loss_test, acc_test_ens_maj = test_img_ensemble_all(net_local_list, args, dataset_test)
+
+ print('Final: acc: {:.2f}, acc (avg): {:.2f}, acc (ens,avg): {:.2f}, acc (ens,maj): {:.2f}'.format(acc_test_local, acc_test_avg, acc_test_ens_avg, acc_test_ens_maj))
+
+ final_results = np.array([[acc_test_local, acc_test_avg, acc_test_ens_avg, acc_test_ens_maj]])
+ final_results = pd.DataFrame(final_results, columns=['acc_test_local', 'acc_test_avg', 'acc_test_ens_avg', 'acc_test_ens_maj'])
+ final_results.to_csv(results_save_path, index=False)
diff --git a/main_mtl.py b/main_mtl.py
index 26f521c..c3ac4dc 100644
--- a/main_mtl.py
+++ b/main_mtl.py
@@ -7,19 +7,21 @@
# import matplotlib.pyplot as plt
import copy
import os
+import pickle
import itertools
import numpy as np
+import pandas as pd
+from tqdm import tqdm
from scipy.stats import mode
from torchvision import datasets, transforms, models
import torch
from torch import nn
-from utils.sampling import mnist_iid, mnist_noniid, cifar10_iid, cifar10_noniid
+from utils.train_utils import get_model, get_data
from utils.options import args_parser
from models.Update import LocalUpdateMTL
-from models.Nets import MLP, CNNMnist, CNNCifar, ResnetCifar
-from models.Fed import FedAvg
-from models.test import test_img, test_img_local
+from models.test import test_img, test_img_local, test_img_local_all, test_img_avg_all, test_img_ensemble_all
+
import pdb
@@ -28,93 +30,24 @@
args = args_parser()
args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu')
- trans_mnist = transforms.Compose([transforms.ToTensor(),
- transforms.Normalize((0.1307,), (0.3081,))])
-
- if args.model == 'resnet':
- trans_cifar_train = transforms.Compose([transforms.RandomCrop(32, padding=4),
- transforms.RandomHorizontalFlip(),
- transforms.Resize([256,256]),
- transforms.CenterCrop(224),
- transforms.ToTensor(),
- transforms.Normalize(mean=[0.485, 0.456, 0.406],
- std=[0.229, 0.224, 0.225])])
- trans_cifar_val = transforms.Compose([transforms.Resize([256,256]),
- transforms.CenterCrop(224),
- transforms.ToTensor(),
- transforms.Normalize(mean=[0.485, 0.456, 0.406],
- std=[0.229, 0.224, 0.225])])
- else:
- trans_cifar_train = transforms.Compose([transforms.RandomCrop(32, padding=4),
- transforms.RandomHorizontalFlip(),
- transforms.ToTensor(),
- transforms.Normalize(mean=[0.485, 0.456, 0.406],
- std=[0.229, 0.224, 0.225])])
- trans_cifar_val = transforms.Compose([transforms.ToTensor(),
- transforms.Normalize(mean=[0.485, 0.456, 0.406],
- std=[0.229, 0.224, 0.225])])
+ base_dir = './save/{}/{}_iid{}_num{}_C{}_le{}/shard{}/{}/'.format(
+ args.dataset, args.model, args.iid, args.num_users, args.frac, args.local_ep, args.shard_per_user, args.results_save)
- # load dataset and split users
- if args.dataset == 'mnist':
- dataset_train = datasets.MNIST('data/mnist/', train=True, download=True, transform=trans_mnist)
- dataset_test = datasets.MNIST('data/mnist/', train=False, download=True, transform=trans_mnist)
- # sample users
- if args.iid:
- dict_users_train = mnist_iid(dataset_train, args.num_users)
- dict_users_test = mnist_iid(dataset_test, args.num_users)
- else:
- dict_users_train, rand_set_all = mnist_noniid(dataset_train, args.num_users, num_shards=200, num_imgs=300, train=True)
- dict_users_test, _ = mnist_noniid(dataset_test, args.num_users, num_shards=200, num_imgs=50, train=False, rand_set_all=rand_set_all)
- save_path = './save/{}/randset_lg_{}_iid{}_num{}_C{}.pt'.format(
- args.results_save, args.dataset, args.iid, args.num_users, args.frac)
- np.save(save_path, rand_set_all)
+ base_save_dir = os.path.join(base_dir, 'mtl')
+ if not os.path.exists(base_save_dir):
+ os.makedirs(base_save_dir, exist_ok=True)
- elif args.dataset == 'cifar10':
- dataset_train = datasets.CIFAR10('data/cifar10', train=True, download=True, transform=trans_cifar_train)
- dataset_test = datasets.CIFAR10('data/cifar10', train=False, download=True, transform=trans_cifar_val)
- if args.iid:
- dict_users_train = cifar10_iid(dataset_train, args.num_users)
- dict_users_test = cifar10_iid(dataset_test, args.num_users)
- else:
- dict_users_train, rand_set_all = cifar10_noniid(dataset_train, args.num_users, num_shards=200, num_imgs=250, train=True)
- dict_users_test, _ = cifar10_noniid(dataset_test, args.num_users, num_shards=200, num_imgs=50, train=False, rand_set_all=rand_set_all)
-
- elif args.dataset == 'cifar100':
- dataset_train = datasets.CIFAR100('data/cifar100', train=True, download=True, transform=trans_cifar_train)
- dataset_test = datasets.CIFAR100('data/cifar100', train=False, download=True, transform=trans_cifar_val)
- if args.iid:
- dict_users_train = cifar10_iid(dataset_train, args.num_users)
- dict_users_test = cifar10_iid(dataset_test, args.num_users)
- else:
- dict_users_train, rand_set_all = cifar10_noniid(dataset_train, args.num_users, num_shards=1000, num_imgs=50, train=True)
- dict_users_test, _ = cifar10_noniid(dataset_test, args.num_users, num_shards=1000, num_imgs=10, train=False, rand_set_all=rand_set_all)
- else:
- exit('Error: unrecognized dataset')
- img_size = dataset_train[0][0].shape
+ dataset_train, dataset_test, dict_users_train, dict_users_test = get_data(args)
+ dict_save_path = os.path.join(base_dir, 'dict_users.pkl')
+ with open(dict_save_path, 'rb') as handle:
+ dict_users_train, dict_users_test = pickle.load(handle)
# build model
- if args.model == 'cnn' and args.dataset in ['cifar10', 'cifar100']:
- net_glob = CNNCifar(args=args).to(args.device)
- elif args.model == 'cnn' and args.dataset == 'mnist':
- net_glob = CNNMnist(args=args).to(args.device)
- elif args.model == 'resnet' and args.dataset in ['cifar10', 'cifar100']:
- net_glob = ResnetCifar(args=args).to(args.device)
- elif args.model == 'mlp':
- len_in = 1
- for x in img_size:
- len_in *= x
- net_glob = MLP(dim_in=len_in, dim_hidden=64, dim_out=args.num_classes).to(args.device)
- else:
- exit('Error: unrecognized model')
+ net_glob = get_model(args)
+ net_glob.train()
print(net_glob)
net_glob.train()
- if args.load_fed:
- fed_model_path = './save/keep/fed_{}_{}_iid{}_num{}_C{}_le{}_gn{}.npy'.format(
- args.dataset, args.model, args.iid, args.num_users, args.frac, args.local_ep, args.grad_norm)
- if len(args.load_fed_name) > 0:
- fed_model_path = './save/keep/{}'.format(args.load_fed_name)
- net_glob.load_state_dict(torch.load(fed_model_path))
total_num_layers = len(net_glob.weight_keys)
w_glob_keys = net_glob.weight_keys[total_num_layers - args.num_layers_keep:]
@@ -129,6 +62,7 @@
percentage_param = 100 * float(num_param_glob) / num_param_local
print('# Params: {} (local), {} (global); Percentage {:.2f} ({}/{})'.format(
num_param_local, num_param_glob, percentage_param, num_param_glob, num_param_local))
+
# generate list of local models for each user
net_local_list = []
for user_ix in range(args.num_users):
@@ -136,104 +70,13 @@
criterion = nn.CrossEntropyLoss()
-
- def test_img_ensemble_all():
- probs_all = []
- preds_all = []
- for idx in range(args.num_users):
- net_local = net_local_list[idx]
- net_local.eval()
- # _, _, probs = test_img(net_local, dataset_test, args, return_probs=True, user_idx=idx)
- acc, loss, probs = test_img(net_local, dataset_test, args, return_probs=True, user_idx=idx)
- # print('Local model: {}, loss: {}, acc: {}'.format(idx, loss, acc))
- probs_all.append(probs.detach())
-
- preds = probs.data.max(1, keepdim=True)[1].cpu().numpy().reshape(-1)
- preds_all.append(preds)
-
- labels = np.array(dataset_test.targets)
- preds_probs = torch.mean(torch.stack(probs_all), dim=0)
-
- # ensemble metrics
- preds_avg = preds_probs.data.max(1, keepdim=True)[1].cpu().numpy().reshape(-1)
- loss_test = criterion(preds_probs, torch.tensor(labels).to(args.device)).item()
- acc_test = (preds_avg == labels).mean() * 100
-
- return loss_test, acc_test
-
- def test_img_local_all():
- acc_test_local = 0
- loss_test_local = 0
- for idx in range(args.num_users):
- net_local = net_local_list[idx]
- net_local.eval()
- a, b = test_img_local(net_local, dataset_test, args, user_idx=idx, idxs=dict_users_test[idx])
-
- acc_test_local += a
- loss_test_local += b
- acc_test_local /= args.num_users
- loss_test_local /= args.num_users
-
- return acc_test_local, loss_test_local
-
- def test_img_avg_all():
- net_glob_temp = copy.deepcopy(net_glob)
- w_keys_epoch = net_glob.state_dict().keys()
- w_glob_temp = {}
- for idx in range(args.num_users):
- net_local = net_local_list[idx]
- w_local = net_local.state_dict()
-
- if len(w_glob_temp) == 0:
- w_glob_temp = copy.deepcopy(w_local)
- else:
- for k in w_keys_epoch:
- w_glob_temp[k] += w_local[k]
-
- for k in w_keys_epoch:
- w_glob_temp[k] = torch.div(w_glob_temp[k], args.num_users)
- net_glob_temp.load_state_dict(w_glob_temp)
- acc_test_avg, loss_test_avg = test_img(net_glob_temp, dataset_test, args)
-
- return acc_test_avg, loss_test_avg
-
- if args.local_ep_pretrain > 0:
- # pretrain each local model
- pretrain_save_path = 'pretrain/{}/{}_{}/user_{}/ep_{}/'.format(args.model, args.dataset,
- 'iid' if args.iid else 'noniid', args.num_users,
- args.local_ep_pretrain)
- if not os.path.exists(pretrain_save_path):
- os.makedirs(pretrain_save_path)
-
- print("\nPretraining local models...")
- for idx in range(args.num_users):
- net_local = net_local_list[idx]
- net_local_path = os.path.join(pretrain_save_path, '{}.pt'.format(idx))
- if os.path.exists(net_local_path): # check if we have a saved model
- net_local.load_state_dict(torch.load(net_local_path))
- else:
- local = LocalUpdateMTL(args=args, dataset=dataset_train, idxs=dict_users_train[idx], pretrain=True)
- w_local, loss = local.train(net=net_local.to(args.device))
- print('Local model {}, Train Epoch Loss: {:.4f}'.format(idx, loss))
- torch.save(net_local.state_dict(), net_local_path)
-
- print("Getting initial loss and acc...")
- acc_test_local, loss_test_local = test_img_local_all()
- acc_test_avg, loss_test_avg = test_img_avg_all()
- loss_test, acc_test = test_img_ensemble_all()
-
- print('Initial Ensemble: Loss (local): {:.3f}, Acc (local): {:.2f}, Loss (avg): {:.3}, Acc (avg): {:.2f}, Loss (ens) {:.3f}, Acc: (ens) {:.2f}, '.format(
- loss_test_local, acc_test_local, loss_test_avg, acc_test_avg, loss_test, acc_test))
-
# training
+ results_save_path = os.path.join(base_save_dir, 'results.csv')
+
loss_train = []
- cv_loss, cv_acc = [], []
- val_loss_pre, counter = 0, 0
net_best = None
- best_loss = None
- best_acc = None
- best_epoch = None
- val_acc_list, net_list = [], []
+ best_acc = np.ones(args.num_users) * -1
+ best_net_list = copy.deepcopy(net_local_list)
lr = args.lr
results = []
@@ -257,82 +100,49 @@ def test_img_avg_all():
idxs_users = np.random.choice(range(args.num_users), m, replace=False)
W = torch.zeros((d, m)).cuda()
- for i in range(m):
- W_local = [net_local_list[i].state_dict()[key].flatten() for key in w_glob_keys]
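+        # stack the flattened shared-layer weights as columns of W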
+ # for idx, user in enumerate(idxs_users):
+ for idx, user in enumerate(range(10)):
+ W_local = [net_local_list[user].state_dict()[key].flatten() for key in w_glob_keys]
W_local = torch.cat(W_local)
- W[:, i] = W_local
+ W[:, idx] = W_local
- if args.verbose:
- print("Round {}: lr: {:.6f}, {}".format(iter, lr, idxs_users))
- for i, idx in enumerate(idxs_users):
- local = LocalUpdateMTL(args=args, dataset=dataset_train, idxs=dict_users_train[idx])
- net_local = net_local_list[idx]
+ for idx, user in enumerate(idxs_users):
+ local = LocalUpdateMTL(args=args, dataset=dataset_train, idxs=dict_users_train[user])
+ net_local = net_local_list[user]
w_local, loss = local.train(net=net_local.to(args.device), lr=lr,
- omega=omega, W_glob=W.clone(), i=i, w_glob_keys=w_glob_keys)
+ omega=omega, W_glob=W.clone(), idx=idx, w_glob_keys=w_glob_keys)
loss_locals.append(copy.deepcopy(loss))
- if (iter+1) % int(args.num_users * args.frac):
- lr *= args.lr_decay
-
loss_avg = sum(loss_locals) / len(loss_locals)
loss_train.append(loss_avg)
# eval
- acc_test_local, loss_test_local = test_img_local_all()
- acc_test_avg, loss_test_avg = test_img_avg_all()
+ acc_test_local, loss_test_local = test_img_local_all(net_local_list, args, dataset_test, dict_users_test, return_all=True)
- # if (iter + 1) % args.test_freq == 0: # this takes too much time, so we run it less frequently
- if (iter + 1) > 2000:
- loss_test, acc_test = test_img_ensemble_all()
- print('Round {:3d}, Avg Loss {:.3f}, Loss (local): {:.3f}, Acc (local): {:.2f}, Loss (avg): {:.3}, Acc (avg): {:.2f}, Loss (ens) {:.3f}, Acc: (ens) {:.2f}, '.format(
- iter, loss_avg, loss_test_local, acc_test_local, loss_test_avg, acc_test_avg, loss_test, acc_test))
- results.append(np.array([iter, loss_avg, loss_test_local, acc_test_local, loss_test_avg, acc_test_avg, loss_test, acc_test]))
+ for user in range(args.num_users):
+ if acc_test_local[user] > best_acc[user]:
+ best_acc[user] = acc_test_local[user]
+ best_net_list[user] = copy.deepcopy(net_local_list[user])
- else:
- print('Round {:3d}, Avg Loss {:.3f}, Loss (local): {:.3f}, Acc (local): {:.2f}, Loss (avg): {:.3}, Acc (avg): {:.2f}'.format(
- iter, loss_avg, loss_test_local, acc_test_local, loss_test_avg, acc_test_avg))
- results.append(np.array([iter, loss_avg, loss_test_local, acc_test_local, loss_test_avg, acc_test_avg, np.nan, np.nan]))
+ model_save_path = os.path.join(base_save_dir, 'model_user{}.pt'.format(user))
+ torch.save(best_net_list[user].state_dict(), model_save_path)
- final_results = np.array(results)
- results_save_path = './log/lg_{}_{}_keep{}_iid{}_num{}_C{}_le{}_gn{}_pt{}_load{}_tfreq{}.npy'.format(
- args.dataset, args.model, args.num_layers_keep, args.iid, args.num_users, args.frac,
- args.local_ep, args.grad_norm, args.local_ep_pretrain, args.load_fed, args.test_freq)
- np.save(results_save_path, final_results)
-
- if best_acc is None or acc_test_local > best_acc:
- best_acc = acc_test_local
- best_epoch = iter
-
- for user in range(args.num_users):
- model_save_path = './save/{}/{}_{}_iid{}/'.format(
- args.results_save, args.dataset, 0, args.iid)
- if not os.path.exists(model_save_path):
- os.makedirs(model_save_path)
-
- model_save_path = './save/{}/{}_{}_iid{}/user{}.pt'.format(
- args.results_save, args.dataset, 0, args.iid, user)
- torch.save(net_local_list[user].state_dict(), model_save_path)
+ acc_test_local, loss_test_local = test_img_local_all(best_net_list, args, dataset_test, dict_users_test)
+ acc_test_avg, loss_test_avg = test_img_avg_all(net_glob, best_net_list, args, dataset_test)
+ print('Round {:3d}, Avg Loss {:.3f}, Loss (local): {:.3f}, Acc (local): {:.2f}, Loss (avg): {:.3}, Acc (avg): {:.2f}'.format(
+ iter, loss_avg, loss_test_local, acc_test_local, loss_test_avg, acc_test_avg))
- for user in range(args.num_users):
- model_save_path = './save/{}/{}_{}_iid{}/user{}.pt'.format(
- args.results_save, args.dataset, 0, args.iid, user)
-
- net_local = net_local_list[idx]
- net_local.load_state_dict(torch.load(model_save_path))
- loss_test, acc_test = test_img_ensemble_all()
-
- print('Best model, iter: {}, acc: {}, acc (ens): {}'.format(best_epoch, best_acc, acc_test))
+ results.append(np.array([iter, acc_test_local, acc_test_avg, best_acc.mean(), None, None]))
+ final_results = np.array(results)
+ final_results = pd.DataFrame(final_results, columns=['epoch', 'acc_test_local', 'acc_test_avg', 'best_acc_local', 'acc_test_ens_avg', 'acc_test_ens_maj'])
+ final_results.to_csv(results_save_path, index=False)
- # plot loss curve
- # plt.figure()
- # plt.plot(range(len(loss_train)), loss_train)
- # plt.ylabel('train_loss')
- # plt.savefig('./log/fed_{}_{}_{}_C{}_iid{}.png'.format(args.dataset, args.model, args.epochs, args.frac, args.iid))
+ acc_test_ens_avg, loss_test, acc_test_ens_maj = test_img_ensemble_all(best_net_list, args, dataset_test)
+ print('Best model, acc (local): {}, acc (ens,avg): {}, acc (ens,maj): {}'.format(best_acc, acc_test_ens_avg, acc_test_ens_maj))
- # testing
- net_glob.eval()
- acc_train, loss_train = test_img(net_glob, dataset_train, args)
- acc_test, loss_test = test_img(net_glob, dataset_test, args)
- print("Training accuracy: {:.2f}".format(acc_train))
- print("Testing accuracy: {:.2f}".format(acc_test))
\ No newline at end of file
+ results.append(np.array(['Final', None, None, best_acc.mean(), acc_test_ens_avg, acc_test_ens_maj]))
+ final_results = np.array(results)
+ final_results = pd.DataFrame(final_results,
+ columns=['epoch', 'acc_test_local', 'acc_test_avg', 'best_acc_local', 'acc_test_ens_avg', 'acc_test_ens_maj'])
+ final_results.to_csv(results_save_path, index=False)
\ No newline at end of file
diff --git a/models/Update.py b/models/Update.py
index 449f454..a2b1f2a 100644
--- a/models/Update.py
+++ b/models/Update.py
@@ -5,7 +5,9 @@
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
+from tqdm import tqdm
import math
+import pdb
class DatasetSplit(Dataset):
def __init__(self, dataset, idxs):
@@ -51,16 +53,6 @@ def train(self, net, idx=-1, lr=0.1):
batch_loss.append(loss.item())
- if not self.pretrain and self.args.verbose and batch_idx % 300 == 0:
- if idx < 0:
- print('Update Epoch: {} [{}/{} ({:.0f}%)], Epoch Loss: {:.4f}, Batch Loss: {:.4f}'.format(
- iter, batch_idx * len(images), len(self.ldr_train.dataset), 100. * batch_idx / len(self.ldr_train),
- sum(batch_loss)/len(batch_loss), loss.item()))
- else:
- print('Local model {}, Update Epoch: {} [{}/{} ({:.0f}%)], Epoch Loss: {:.4f}, Batch Loss: {:.4f}'.format(
- idx, iter, batch_idx * len(images), len(self.ldr_train.dataset), 100. * batch_idx / len(self.ldr_train),
- sum(batch_loss)/len(batch_loss), loss.item()))
-
epoch_loss.append(sum(batch_loss)/len(batch_loss))
return net.state_dict(), sum(epoch_loss) / len(epoch_loss)
@@ -74,7 +66,7 @@ def __init__(self, args, dataset=None, idxs=None, pretrain=False):
self.ldr_train = DataLoader(DatasetSplit(dataset, idxs), batch_size=self.args.local_bs, shuffle=True)
self.pretrain = pretrain
- def train(self, net, idx=-1, lr=0.1, omega=None, W_glob=None, i=None, w_glob_keys=None):
+ def train(self, net, lr=0.1, omega=None, W_glob=None, idx=None, w_glob_keys=None):
net.train()
# train and update
optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.5)
@@ -98,12 +90,12 @@ def train(self, net, idx=-1, lr=0.1, omega=None, W_glob=None, i=None, w_glob_key
W_local = [net.state_dict(keep_vars=True)[key].flatten() for key in w_glob_keys]
W_local = torch.cat(W_local)
- W[:, i] = W_local
+ W[:, idx] = W_local
loss_regularizer = 0
loss_regularizer += W.norm() ** 2
- k = 10000
+ k = 4000
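+        # accumulate the trace regularizer over row blocks of size k (rows beyond the last full block are skipped)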
for i in range(W.shape[0] // k):
x = W[i * k:(i+1) * k, :]
loss_regularizer += x.mm(omega).mm(x.T).trace()
@@ -116,16 +108,6 @@ def train(self, net, idx=-1, lr=0.1, omega=None, W_glob=None, i=None, w_glob_key
batch_loss.append(loss.item())
- if not self.pretrain and self.args.verbose and batch_idx % 300 == 0:
- if idx < 0:
- print('Update Epoch: {} [{}/{} ({:.0f}%)], Epoch Loss: {:.4f}, Batch Loss: {:.4f}'.format(
- iter, batch_idx * len(images), len(self.ldr_train.dataset), 100. * batch_idx / len(self.ldr_train),
- sum(batch_loss)/len(batch_loss), loss.item()))
- else:
- print('Local model {}, Update Epoch: {} [{}/{} ({:.0f}%)], Epoch Loss: {:.4f}, Batch Loss: {:.4f}'.format(
- idx, iter, batch_idx * len(images), len(self.ldr_train.dataset), 100. * batch_idx / len(self.ldr_train),
- sum(batch_loss)/len(batch_loss), loss.item()))
-
epoch_loss.append(sum(batch_loss)/len(batch_loss))
return net.state_dict(), sum(epoch_loss) / len(epoch_loss)
diff --git a/models/test.py b/models/test.py
index 084d22a..2df3fcc 100644
--- a/models/test.py
+++ b/models/test.py
@@ -2,10 +2,14 @@
# -*- coding: utf-8 -*-
# @python: 3.6
+import copy
+import numpy as np
+from scipy import stats
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
+import pdb
class DatasetSplit(Dataset):
def __init__(self, dataset, idxs):
@@ -19,7 +23,6 @@ def __getitem__(self, item):
image, label = self.dataset[self.idxs[item]]
return image, label
-
def test_img(net_g, datatest, args, return_probs=False, user_idx=-1):
net_g.eval()
# testing
@@ -84,3 +87,72 @@ def test_img_local(net_g, dataset, args, user_idx=-1, idxs=None):
user_idx, test_loss, correct, len(data_loader.dataset), accuracy))
return accuracy, test_loss
+
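+# evaluate each user's model on that user's local test indices; return per-user arrays or their means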
+def test_img_local_all(net_local_list, args, dataset_test, dict_users_test, return_all=False):
+ acc_test_local = np.zeros(args.num_users)
+ loss_test_local = np.zeros(args.num_users)
+ for idx in range(args.num_users):
+ net_local = net_local_list[idx]
+ net_local.eval()
+ a, b = test_img_local(net_local, dataset_test, args, user_idx=idx, idxs=dict_users_test[idx])
+
+ acc_test_local[idx] = a
+ loss_test_local[idx] = b
+
+ if return_all:
+ return acc_test_local, loss_test_local
+ return acc_test_local.mean(), loss_test_local.mean()
+
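+# uniformly average all local model weights into one network and evaluate it on the full test set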
+def test_img_avg_all(net_glob, net_local_list, args, dataset_test, return_net=False):
+ net_glob_temp = copy.deepcopy(net_glob)
+ w_keys_epoch = net_glob.state_dict().keys()
+ w_glob_temp = {}
+ for idx in range(args.num_users):
+ net_local = net_local_list[idx]
+ w_local = net_local.state_dict()
+
+ if len(w_glob_temp) == 0:
+ w_glob_temp = copy.deepcopy(w_local)
+ else:
+ for k in w_keys_epoch:
+ w_glob_temp[k] += w_local[k]
+
+ for k in w_keys_epoch:
+ w_glob_temp[k] = torch.div(w_glob_temp[k], args.num_users)
+ net_glob_temp.load_state_dict(w_glob_temp)
+ acc_test_avg, loss_test_avg = test_img(net_glob_temp, dataset_test, args)
+
+ if return_net:
+ return acc_test_avg, loss_test_avg, net_glob_temp
+ return acc_test_avg, loss_test_avg
+
+criterion = nn.CrossEntropyLoss()
+
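+# ensemble the local models on the full test set via averaged probabilities and majority vote over predictions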
+def test_img_ensemble_all(net_local_list, args, dataset_test):
+ probs_all = []
+ preds_all = []
+ for idx in range(args.num_users):
+ net_local = net_local_list[idx]
+ net_local.eval()
+ # _, _, probs = test_img(net_local, dataset_test, args, return_probs=True, user_idx=idx)
+ acc, loss, probs = test_img(net_local, dataset_test, args, return_probs=True, user_idx=idx)
+ # print('Local model: {}, loss: {}, acc: {}'.format(idx, loss, acc))
+ probs_all.append(probs.detach())
+
+ preds = probs.data.max(1, keepdim=True)[1].cpu().numpy().reshape(-1)
+ preds_all.append(preds)
+
+ labels = np.array(dataset_test.targets)
+ preds_probs = torch.mean(torch.stack(probs_all), dim=0)
+
+ # ensemble (avg) metrics
+ preds_avg = preds_probs.data.max(1, keepdim=True)[1].cpu().numpy().reshape(-1)
+ loss_test = criterion(preds_probs, torch.tensor(labels).to(args.device)).item()
+ acc_test_avg = (preds_avg == labels).mean() * 100
+
+ # ensemble (maj)
+ preds_all = np.array(preds_all).T
+ preds_maj = stats.mode(preds_all, axis=1)[0].reshape(-1)
+ acc_test_maj = (preds_maj == labels).mean() * 100
+
+ return acc_test_avg, loss_test, acc_test_maj
\ No newline at end of file
diff --git a/scripts/cifar10_ablation.sh b/scripts/cifar10_ablation.sh
new file mode 100755
index 0000000..14e4178
--- /dev/null
+++ b/scripts/cifar10_ablation.sh
@@ -0,0 +1,25 @@
+#!/bin/bash
+
+SHARDS=2
+while getopts s:n: option; do
+ case "${option}" in
+ s) SHARDS=${OPTARG};;
+ esac
+done
+
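+# for each run and client count: train FedAvg, train the local-only baseline, then run LG-FedAvg from several FedAvg checkpoints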
+for RUN in 1 2 3 4 5; do
+ for NUM_USERS in 20 50 200; do
+ python3 main_fed.py --dataset cifar10 --model cnn --num_classes 10 --epochs 2000 --lr 0.1 \
+ --num_users ${NUM_USERS} --shard_per_user ${SHARDS} --frac 0.1 --local_ep 1 --local_bs 50 --results_save run${RUN}
+
+ python3 main_local.py --dataset cifar10 --model cnn --num_classes 10 --epochs 200 --lr 0.1 \
+ --num_users ${NUM_USERS} --shard_per_user ${SHARDS} --frac 0.1 --local_ep 1 --local_bs 50 --results_save run${RUN}
+
+ for FED_MODEL in 1000 1200 1400 1600 1800; do
+ python3 main_lg.py --dataset cifar10 --model cnn --num_classes 10 --epochs 200 --lr 0.1 \
+ --num_users ${NUM_USERS} --shard_per_user ${SHARDS} --frac 0.1 --local_ep 1 --local_bs 50 --num_layers_keep 2 \
+ --results_save run${RUN} --load_fed best_${FED_MODEL}.pt
+ done
+ done
+done
+
diff --git a/scripts/main_fed_cifar10.sh b/scripts/main_fed_cifar10.sh
new file mode 100755
index 0000000..f14665a
--- /dev/null
+++ b/scripts/main_fed_cifar10.sh
@@ -0,0 +1,19 @@
+#!/bin/bash
+
+NUM_USERS=100
+SHARDS=2
+while getopts s:n: option; do
+ case "${option}" in
+ s) SHARDS=${OPTARG};;
+ n) NUM_USERS=${OPTARG};;
+ esac
+done
+
+
+for RUN in 1 2 3 4 5; do
+ python3 main_fed.py --dataset cifar10 --model cnn --num_classes 10 --epochs 2000 --lr 0.1 \
+ --num_users ${NUM_USERS} --shard_per_user ${SHARDS} --frac 0.1 --local_ep 1 --local_bs 50 --results_save run${RUN}
+ python3 main_local.py --dataset cifar10 --model cnn --num_classes 10 --epochs 200 --lr 0.1 \
+ --num_users ${NUM_USERS} --shard_per_user ${SHARDS} --frac 0.1 --local_ep 1 --local_bs 50 --results_save run${RUN}
+done
+
diff --git a/scripts/main_fed_cifar10_iid.sh b/scripts/main_fed_cifar10_iid.sh
new file mode 100755
index 0000000..a1929d7
--- /dev/null
+++ b/scripts/main_fed_cifar10_iid.sh
@@ -0,0 +1,17 @@
+#!/bin/bash
+
+NUM_USERS=100
+while getopts s:n: option; do
+ case "${option}" in
+ n) NUM_USERS=${OPTARG};;
+ esac
+done
+
+
+for RUN in 1 2 3 4 5; do
+ python3 main_fed.py --dataset cifar10 --model cnn --num_classes 10 --epochs 2000 --lr 0.1 \
+ --num_users ${NUM_USERS} --shard_per_user 10 --frac 0.1 --local_ep 1 --local_bs 50 --results_save run${RUN} --iid
+ python3 main_local.py --dataset cifar10 --model cnn --num_classes 10 --epochs 400 --lr 0.1 \
+ --num_users ${NUM_USERS} --shard_per_user 10 --frac 0.1 --local_ep 1 --local_bs 50 --results_save run${RUN} --iid
+done
+
diff --git a/scripts/main_fed_mnist.sh b/scripts/main_fed_mnist.sh
new file mode 100755
index 0000000..dbe4002
--- /dev/null
+++ b/scripts/main_fed_mnist.sh
@@ -0,0 +1,18 @@
+#!/bin/bash
+
+NUM_USERS=100
+SHARDS=2
+while getopts s:n: option; do
+ case "${option}" in
+ s) SHARDS=${OPTARG};;
+ n) NUM_USERS=${OPTARG};;
+ esac
+done
+
+for RUN in 1 2 3 4 5; do
+ python3 main_fed.py --dataset mnist --model mlp --num_classes 10 --epochs 1000 --lr 0.05 \
+ --num_users ${NUM_USERS} --shard_per_user ${SHARDS} --frac 0.1 --local_ep 1 --local_bs 10 --results_save run${RUN}
+ python3 main_local.py --dataset mnist --model mlp --num_classes 10 --epochs 100 --lr 0.05 \
+ --num_users ${NUM_USERS} --shard_per_user ${SHARDS} --frac 0.1 --local_ep 1 --local_bs 10 --results_save run${RUN}
+done
+
diff --git a/scripts/main_fed_mnist_iid.sh b/scripts/main_fed_mnist_iid.sh
new file mode 100755
index 0000000..754486d
--- /dev/null
+++ b/scripts/main_fed_mnist_iid.sh
@@ -0,0 +1,16 @@
+#!/bin/bash
+
+NUM_USERS=100
+while getopts s:n: option; do
+ case "${option}" in
+ n) NUM_USERS=${OPTARG};;
+ esac
+done
+
+for RUN in 1 2 3 4 5; do
+ python3 main_fed.py --dataset mnist --model mlp --num_classes 10 --epochs 1000 --lr 0.05 \
+ --num_users ${NUM_USERS} --shard_per_user 10 --frac 0.1 --local_ep 1 --local_bs 10 --results_save run${RUN} --iid
+ python3 main_local.py --dataset mnist --model mlp --num_classes 10 --epochs 400 --lr 0.05 \
+ --num_users ${NUM_USERS} --shard_per_user 10 --frac 0.1 --local_ep 1 --local_bs 10 --results_save run${RUN} --iid
+done
+
diff --git a/scripts/main_lg_cifar10.sh b/scripts/main_lg_cifar10.sh
new file mode 100755
index 0000000..a09a0f3
--- /dev/null
+++ b/scripts/main_lg_cifar10.sh
@@ -0,0 +1,19 @@
+#!/bin/bash
+
+NUM_USERS=100
+SHARDS=2
+while getopts s:n: option; do
+ case "${option}" in
+ s) SHARDS=${OPTARG};;
+ n) NUM_USERS=${OPTARG};;
+ esac
+done
+
+for FED_MODEL in 800 1000 1200 1400 1600 1800; do
+ for RUN in 1 2 3 4 5; do
+ python3 main_lg.py --dataset cifar10 --model cnn --num_classes 10 --epochs 200 --lr 0.1 \
+ --num_users ${NUM_USERS} --shard_per_user ${SHARDS} --frac 0.1 --local_ep 1 --local_bs 50 --num_layers_keep 2 \
+ --results_save run${RUN} --load_fed best_${FED_MODEL}.pt
+ done
+done
+
diff --git a/scripts/main_lg_cifar10_iid.sh b/scripts/main_lg_cifar10_iid.sh
new file mode 100755
index 0000000..3abb075
--- /dev/null
+++ b/scripts/main_lg_cifar10_iid.sh
@@ -0,0 +1,17 @@
+#!/bin/bash
+
+NUM_USERS=100
+while getopts s:n: option; do
+ case "${option}" in
+ n) NUM_USERS=${OPTARG};;
+ esac
+done
+
+for FED_MODEL in 800 1000 1200 1400 1600 1800; do
+ for RUN in 1 2 3 4 5; do
+ python3 main_lg.py --dataset cifar10 --model cnn --num_classes 10 --epochs 200 --lr 0.1 \
+ --num_users ${NUM_USERS} --shard_per_user 10 --frac 0.1 --local_ep 1 --local_bs 50 --num_layers_keep 2 \
+ --results_save run${RUN} --load_fed best_${FED_MODEL}.pt --iid
+ done
+done
+
diff --git a/scripts/main_lg_mnist.sh b/scripts/main_lg_mnist.sh
new file mode 100755
index 0000000..65344b8
--- /dev/null
+++ b/scripts/main_lg_mnist.sh
@@ -0,0 +1,19 @@
+#!/bin/bash
+
+NUM_USERS=100
+SHARDS=2
+while getopts s:n: option; do
+ case "${option}" in
+ s) SHARDS=${OPTARG};;
+ n) NUM_USERS=${OPTARG};;
+ esac
+done
+
+for FED_MODEL in 400 500 600 700 800; do
+ for RUN in 1 2 3 4 5; do
+ python3 main_lg.py --dataset mnist --model mlp --num_classes 10 --epochs 200 --lr 0.05 \
+ --num_users ${NUM_USERS} --shard_per_user ${SHARDS} --frac 0.1 --local_ep 1 --local_bs 10 --num_layers_keep 3 \
+ --results_save run${RUN} --load_fed best_${FED_MODEL}.pt
+ done
+done
+
diff --git a/scripts/main_lg_mnist_iid.sh b/scripts/main_lg_mnist_iid.sh
new file mode 100755
index 0000000..6c4da70
--- /dev/null
+++ b/scripts/main_lg_mnist_iid.sh
@@ -0,0 +1,17 @@
+#!/bin/bash
+
+NUM_USERS=100
+while getopts s:n: option; do
+ case "${option}" in
+ n) NUM_USERS=${OPTARG};;
+ esac
+done
+
+for FED_MODEL in 400 500 600 700 800; do
+ for RUN in 1 2 3 4 5; do
+ python3 main_lg.py --dataset mnist --model mlp --num_classes 10 --epochs 200 --lr 0.05 \
+ --num_users ${NUM_USERS} --shard_per_user 10 --frac 0.1 --local_ep 1 --local_bs 10 --num_layers_keep 3 \
+ --results_save run${RUN} --load_fed best_${FED_MODEL}.pt --iid
+ done
+done
+
diff --git a/scripts/main_mtl_cifar10.sh b/scripts/main_mtl_cifar10.sh
new file mode 100755
index 0000000..89baa3a
--- /dev/null
+++ b/scripts/main_mtl_cifar10.sh
@@ -0,0 +1,17 @@
+#!/bin/bash
+
+NUM_USERS=100
+SHARDS=2
+while getopts s:n: option; do
+ case "${option}" in
+ s) SHARDS=${OPTARG};;
+ n) NUM_USERS=${OPTARG};;
+ esac
+done
+
+for RUN in 1 2 3 4 5; do
+ python3 main_mtl.py --dataset cifar10 --model cnn --num_classes 10 --epochs 1800 --lr 0.1 \
+ --num_users ${NUM_USERS} --shard_per_user ${SHARDS} --frac 0.1 --local_ep 1 --local_bs 50 --num_layers_keep 5 \
+ --results_save run${RUN}
+done
+
diff --git a/scripts/main_mtl_mnist.sh b/scripts/main_mtl_mnist.sh
new file mode 100755
index 0000000..9282ff4
--- /dev/null
+++ b/scripts/main_mtl_mnist.sh
@@ -0,0 +1,17 @@
+#!/bin/bash
+
+NUM_USERS=100
+SHARDS=2
+while getopts s:n: option; do
+ case "${option}" in
+ s) SHARDS=${OPTARG};;
+ n) NUM_USERS=${OPTARG};;
+ esac
+done
+
+for RUN in 1 2 3 4 5; do
+ python3 main_mtl.py --dataset mnist --model mlp --num_classes 10 --epochs 800 --lr 0.05 \
+ --num_users ${NUM_USERS} --shard_per_user ${SHARDS} --frac 0.1 --local_ep 1 --local_bs 10 --num_layers_keep 5 \
+ --results_save run${RUN}
+done
+
diff --git a/scripts/mnist_ablation.sh b/scripts/mnist_ablation.sh
new file mode 100755
index 0000000..095bff6
--- /dev/null
+++ b/scripts/mnist_ablation.sh
@@ -0,0 +1,27 @@
+#!/bin/bash
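+# Ablation over the number of users on MNIST: FedAvg, local-only training, and LG-FedAvg starting from intermediate FedAvg checkpoints.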
+
+SHARDS=2
+while getopts s:n: option; do
+ case "${option}" in
+ s) SHARDS=${OPTARG};;
+ esac
+done
+
+for RUN in 1 2 3 4 5; do
+ for NUM_USERS in 20 50 200; do
+ python3 main_fed.py --dataset mnist --model mlp --num_classes 10 --epochs 1000 --lr 0.05 \
+ --num_users ${NUM_USERS} --shard_per_user ${SHARDS} --frac 0.1 --local_ep 1 --local_bs 10 --results_save run${RUN}
+
+ python3 main_local.py --dataset mnist --model mlp --num_classes 10 --epochs 100 --lr 0.05 \
+ --num_users ${NUM_USERS} --shard_per_user ${SHARDS} --frac 0.1 --local_ep 1 --local_bs 10 --results_save run${RUN}
+
+ for FED_MODEL in 400 500 600 700 800; do
+      python3 main_lg.py --dataset mnist --model mlp --num_classes 10 --epochs 200 --lr 0.05 \
+        --num_users ${NUM_USERS} --shard_per_user ${SHARDS} --frac 0.1 --local_ep 1 --local_bs 10 --num_layers_keep 3 \
+        --results_save run${RUN} --load_fed best_${FED_MODEL}.pt
+ done
+ done
+done
+
diff --git a/test_local_numusers.py b/test_local_numusers.py
new file mode 100644
index 0000000..5647914
--- /dev/null
+++ b/test_local_numusers.py
@@ -0,0 +1,109 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Python version: 3.6
+
+import copy
+import os
+import pickle
+import pandas as pd
+import numpy as np
+import torch
+from torch import nn
+from torch.utils.data import DataLoader
+
+from utils.options import args_parser
+from utils.train_utils import get_data, get_model
+from models.Update import DatasetSplit
+from models.test import test_img_local, test_img_local_all, test_img_avg_all, test_img_ensemble_all
+
+from torchvision import datasets, transforms
+from models.Nets import MLP, CNNMnist, CNNCifar
+from utils.sampling import iid, noniid, noniid_replace
+
+import pdb
+
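+# Train an independent local model for each user and record the average local test accuracy as the number of users varies.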
+results = {}
+
+# manually set arguments
+args = args_parser()
+args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu')
+args.num_classes = 10
+args.epochs = 100
+if args.dataset == "mnist":
+ args.local_bs = 10
+else:
+ args.local_bs = 50
+
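+# stop training a user's local model after this many epochs without improvement in test accuracy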
+early_stopping = 20
+
+args.shard_per_user = 2
+for num_users in [5, 10, 20, 50, 100, 200, 500, 1000]:
+ results[num_users] = []
+ for run in range(10):
+ args.num_users = num_users
+
+ # dataset
+ dataset_train, dataset_test, dict_users_train, dict_users_test = get_data(args)
+
+ # model
+ net_glob = get_model(args)
+ net_glob.train()
+
+ net_local_list = []
+ for user_ix in range(args.num_users):
+ net_local_list.append(copy.deepcopy(net_glob))
+
+ criterion = nn.CrossEntropyLoss()
+ lr = args.lr
+
+ # run each local model
+ for user, net_local in enumerate(net_local_list):
+ net_best = None
+ best_acc = None
+
+ ldr_train = DataLoader(DatasetSplit(dataset_train, dict_users_train[user]), batch_size=args.local_bs, shuffle=True)
+ optimizer = torch.optim.SGD(net_local.parameters(), lr=lr, momentum=0.5)
+
+ epochs_since_improvement = 0
+ for iter in range(args.epochs):
+ for batch_idx, (images, labels) in enumerate(ldr_train):
+ images, labels = images.to(args.device), labels.to(args.device)
+ net_local.zero_grad()
+ log_probs = net_local(images)
+
+ loss = criterion(log_probs, labels)
+ loss.backward()
+ optimizer.step()
+
+ acc_test, loss_test = test_img_local(net_local, dataset_test, args, user_idx=user, idxs=dict_users_test[user])
+
+ epochs_since_improvement += 1
+ if best_acc is None or acc_test > best_acc:
+ best_acc = acc_test
+ net_best = copy.deepcopy(net_local)
+ epochs_since_improvement = 0
+
+ print('User {}, Epoch {}, Acc {:.2f}'.format(user, iter, acc_test))
+ if epochs_since_improvement >= early_stopping:
+ print("Early stopping...")
+ break
+
+ net_local_list[user] = net_best
+
+ acc_test_local, loss_test_local = test_img_local_all(net_local_list, args, dataset_test, dict_users_test)
+ results[num_users].append(acc_test_local)
+
+ results_save_path = "save/{}/test_local.pkl".format(args.dataset)
+ with open(results_save_path, 'wb') as file:
+ pickle.dump(results, file)
+
+# summarize the saved results: mean and std of local test accuracy for each number of users
+with open(results_save_path, 'rb') as file:
+    x = pickle.load(file)
+
+for key, value in x.items():
+    print(key, np.array(value).mean())
+
+for key, value in x.items():
+    print(key, np.array(value).std())
+
diff --git a/utils/options.py b/utils/options.py
index 292cf7e..d61052b 100644
--- a/utils/options.py
+++ b/utils/options.py
@@ -9,6 +9,7 @@ def args_parser():
# federated arguments
parser.add_argument('--epochs', type=int, default=10, help="rounds of training")
parser.add_argument('--num_users', type=int, default=100, help="number of users: K")
+ parser.add_argument('--shard_per_user', type=int, default=2, help="classes per user")
parser.add_argument('--frac', type=float, default=0.1, help="the fraction of clients: C")
parser.add_argument('--local_ep', type=int, default=5, help="the number of local epochs: E")
parser.add_argument('--local_bs', type=int, default=10, help="local batch size: B")
diff --git a/utils/sampling.py b/utils/sampling.py
index 920689b..65cae04 100755
--- a/utils/sampling.py
+++ b/utils/sampling.py
@@ -2,9 +2,12 @@
# -*- coding: utf-8 -*-
# Python version: 3.6
-
+import math
+import random
+from itertools import permutations
import numpy as np
-from torchvision import datasets, transforms
+import torch
+import pdb
def fair_iid(dataset, num_users):
"""
@@ -64,7 +67,7 @@ def fair_noniid(train_data, num_users, num_shards=200, num_imgs=300, train=True,
return dict_users, rand_set_all
-def mnist_iid(dataset, num_users):
+def iid(dataset, num_users):
"""
Sample I.I.D. client data from MNIST dataset
:param dataset:
@@ -78,140 +81,95 @@ def mnist_iid(dataset, num_users):
all_idxs = list(set(all_idxs) - dict_users[i])
return dict_users
-
-def mnist_noniid(dataset, num_users, num_shards=200, num_imgs=300, train=True, rand_set_all=[]):
+def noniid(dataset, num_users, shard_per_user, rand_set_all=[]):
"""
Sample non-I.I.D client data from MNIST dataset
:param dataset:
:param num_users:
:return:
"""
-
- assert num_shards % num_users == 0
- shard_per_user = int(num_shards / num_users)
-
- idx_shard = [i for i in range(num_shards)]
dict_users = {i: np.array([], dtype='int64') for i in range(num_users)}
- idxs = np.arange(num_shards*num_imgs)
- labels = np.array(dataset.targets)
-
- assert num_shards * num_imgs == len(labels)
- # sort labels
- idxs_labels = np.vstack((idxs, labels))
- idxs_labels = idxs_labels[:,idxs_labels[1,:].argsort()]
- idxs = idxs_labels[0,:]
+ idxs_dict = {}
+ for i in range(len(dataset)):
+ label = torch.tensor(dataset.targets[i]).item()
+ if label not in idxs_dict.keys():
+ idxs_dict[label] = []
+ idxs_dict[label].append(i)
+
+ num_classes = len(np.unique(dataset.targets))
+ shard_per_class = int(shard_per_user * num_users / num_classes)
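+    # each class is split into shard_per_class shards; every user is later assigned shard_per_user of them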
+ for label in idxs_dict.keys():
+ x = idxs_dict[label]
+ num_leftover = len(x) % shard_per_class
+ leftover = x[-num_leftover:] if num_leftover > 0 else []
+ x = np.array(x[:-num_leftover]) if num_leftover > 0 else np.array(x)
+ x = x.reshape((shard_per_class, -1))
+ x = list(x)
+
+ for i, idx in enumerate(leftover):
+ x[i] = np.concatenate([x[i], [idx]])
+ idxs_dict[label] = x
- # divide and assign
if len(rand_set_all) == 0:
- for i in range(num_users):
- rand_set = set(np.random.choice(idx_shard, shard_per_user, replace=False))
- for rand in rand_set:
- rand_set_all.append(rand)
-
- idx_shard = list(set(idx_shard) - rand_set) # remove shards from possible choices for other users
- for rand in rand_set:
- dict_users[i] = np.concatenate((dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0)
-
- else: # this only works if the train and test set have the same distribution of labels
- for i in range(num_users):
- rand_set = rand_set_all[i*shard_per_user: (i+1)*shard_per_user]
- for rand in rand_set:
- dict_users[i] = np.concatenate((dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0)
-
- return dict_users, rand_set_all
-
-
-def cifar10_iid(dataset, num_users):
- """
- Sample I.I.D. client data from CIFAR10 dataset
- :param dataset:
- :param num_users:
- :return: dict of image index
- """
- num_items = int(len(dataset)/num_users)
- dict_users, all_idxs = {}, [i for i in range(len(dataset))]
- for i in range(num_users):
- dict_users[i] = set(np.random.choice(all_idxs, num_items, replace=False))
- all_idxs = list(set(all_idxs) - dict_users[i])
- return dict_users
-
-'''
-def cifar10_noniid(dataset, num_users):
- """
- Sample non-I.I.D client data from MNIST dataset
- :param dataset:
- :param num_users:
- :return:
- """
- num_shards, num_imgs = 200, 250
- idx_shard = [i for i in range(num_shards)]
- dict_users = {i: np.array([], dtype='int64') for i in range(num_users)}
- idxs = np.arange(num_shards*num_imgs)
- labels = np.array(dataset.train_labels)
-
- # sort labels
- idxs_labels = np.vstack((idxs, labels))
- idxs_labels = idxs_labels[:,idxs_labels[1,:].argsort()]
- idxs = idxs_labels[0,:]
+ rand_set_all = list(range(num_classes)) * shard_per_class
+ random.shuffle(rand_set_all)
+ rand_set_all = np.array(rand_set_all).reshape((num_users, -1))
# divide and assign
for i in range(num_users):
- rand_set = set(np.random.choice(idx_shard, 2, replace=False))
- idx_shard = list(set(idx_shard) - rand_set)
- for rand in rand_set:
- dict_users[i] = np.concatenate((dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0)
- return dict_users
-'''
+ rand_set_label = rand_set_all[i]
+ rand_set = []
+ for label in rand_set_label:
+ idx = np.random.choice(len(idxs_dict[label]), replace=False)
+ rand_set.append(idxs_dict[label].pop(idx))
+ dict_users[i] = np.concatenate(rand_set)
+
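+    # sanity check: each user holds at most shard_per_user classes and every sample is assigned exactly once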
+    test = []
+    for key, value in dict_users.items():
+        x = np.unique(torch.tensor(dataset.targets)[value])
+        assert len(x) <= shard_per_user
+        test.append(value)
+    test = np.concatenate(test)
+    assert len(test) == len(dataset)
+    assert len(set(list(test))) == len(dataset)
+
+ return dict_users, rand_set_all
-def cifar10_noniid(dataset, num_users, num_shards=200, num_imgs=250, train=True, rand_set_all=[]):
+def noniid_replace(dataset, num_users, shard_per_user, rand_set_all=[]):
"""
Sample non-I.I.D client data from MNIST dataset
:param dataset:
:param num_users:
:return:
"""
-
- assert num_shards % num_users == 0
- shard_per_user = int(num_shards / num_users)
-
- idx_shard = [i for i in range(num_shards)]
+ imgs_per_shard = int(len(dataset) / (num_users * shard_per_user))
dict_users = {i: np.array([], dtype='int64') for i in range(num_users)}
- idxs = np.arange(num_shards*num_imgs)
- labels = np.array(dataset.targets)
-
- assert num_shards * num_imgs == len(labels)
- # sort labels
- idxs_labels = np.vstack((idxs, labels))
- idxs_labels = idxs_labels[:,idxs_labels[1,:].argsort()]
- idxs = idxs_labels[0,:]
+ idxs_dict = {}
+ for i in range(len(dataset)):
+ label = torch.tensor(dataset.targets[i]).item()
+ if label not in idxs_dict.keys():
+ idxs_dict[label] = []
+ idxs_dict[label].append(i)
- # divide and assign
+ num_classes = len(np.unique(dataset.targets))
if len(rand_set_all) == 0:
for i in range(num_users):
- rand_set = set(np.random.choice(idx_shard, shard_per_user, replace=False))
- for rand in rand_set:
- rand_set_all.append(rand)
-
- idx_shard = list(set(idx_shard) - rand_set) # remove shards from possible choices for other users
- for rand in rand_set:
- dict_users[i] = np.concatenate((dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0)
+ x = np.random.choice(np.arange(num_classes), shard_per_user, replace=False)
+ rand_set_all.append(x)
- else: # this only works if the train and test set have the same distribution of labels
- for i in range(num_users):
- rand_set = rand_set_all[i*shard_per_user: (i+1)*shard_per_user]
- for rand in rand_set:
- dict_users[i] = np.concatenate((dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0)
-
- return dict_users, rand_set_all
-
-
-if __name__ == '__main__':
- dataset_train = datasets.MNIST('../data/mnist/', train=True, download=True,
- transform=transforms.Compose([
- transforms.ToTensor(),
- transforms.Normalize((0.1307,), (0.3081,))
- ]))
- num = 100
- d = mnist_noniid(dataset_train, num)
+ # divide and assign
+ for i in range(num_users):
+ rand_set_label = rand_set_all[i]
+ rand_set = []
+ for label in rand_set_label:
+ x = np.random.choice(idxs_dict[label], imgs_per_shard, replace=False)
+ rand_set.append(x)
+ dict_users[i] = np.concatenate(rand_set)
+
+ for key, value in dict_users.items():
+        assert len(np.unique(torch.tensor(dataset.targets)[value])) == shard_per_user
+
+ return dict_users, rand_set_all
\ No newline at end of file
diff --git a/utils/train_utils.py b/utils/train_utils.py
index 2c227d3..8015190 100644
--- a/utils/train_utils.py
+++ b/utils/train_utils.py
@@ -1,46 +1,59 @@
from torchvision import datasets, transforms
from models.Nets import MLP, CNNMnist, CNNCifar
-from utils.sampling import mnist_iid, mnist_noniid, cifar10_iid, cifar10_noniid
+from utils.sampling import iid, noniid
-def get_data(args, rand_set_all=[]):
- trans_mnist = transforms.Compose([transforms.ToTensor(),
- transforms.Normalize((0.1307,), (0.3081,))])
- trans_cifar_train = transforms.Compose([transforms.RandomCrop(32, padding=4),
- transforms.RandomHorizontalFlip(),
- transforms.ToTensor(),
- transforms.Normalize(mean=[0.485, 0.456, 0.406],
- std=[0.229, 0.224, 0.225])])
- trans_cifar_val = transforms.Compose([transforms.ToTensor(),
+trans_mnist = transforms.Compose([transforms.ToTensor(),
+ transforms.Normalize((0.1307,), (0.3081,))])
+trans_cifar10_train = transforms.Compose([transforms.RandomCrop(32, padding=4),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])])
+trans_cifar10_val = transforms.Compose([transforms.ToTensor(),
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225])])
+trans_cifar100_train = transforms.Compose([transforms.RandomCrop(32, padding=4),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.507, 0.487, 0.441],
+ std=[0.267, 0.256, 0.276])])
+trans_cifar100_val = transforms.Compose([transforms.ToTensor(),
+ transforms.Normalize(mean=[0.507, 0.487, 0.441],
+ std=[0.267, 0.256, 0.276])])
+def get_data(args):
if args.dataset == 'mnist':
dataset_train = datasets.MNIST('data/mnist/', train=True, download=True, transform=trans_mnist)
dataset_test = datasets.MNIST('data/mnist/', train=False, download=True, transform=trans_mnist)
# sample users
if args.iid:
- dict_users_train = mnist_iid(dataset_train, args.num_users)
- dict_users_test = mnist_iid(dataset_test, args.num_users)
+ dict_users_train = iid(dataset_train, args.num_users)
+ dict_users_test = iid(dataset_test, args.num_users)
else:
- dict_users_train, rand_set_all = mnist_noniid(dataset_train, args.num_users, num_shards=200, num_imgs=300,
- train=True, rand_set_all=rand_set_all)
- dict_users_test, rand_set_all = mnist_noniid(dataset_test, args.num_users, num_shards=200, num_imgs=50,
- train=False, rand_set_all=rand_set_all)
+ dict_users_train, rand_set_all = noniid(dataset_train, args.num_users, args.shard_per_user)
+ dict_users_test, rand_set_all = noniid(dataset_test, args.num_users, args.shard_per_user, rand_set_all=rand_set_all)
elif args.dataset == 'cifar10':
- dataset_train = datasets.CIFAR10('data/cifar10', train=True, download=True, transform=trans_cifar_train)
- dataset_test = datasets.CIFAR10('data/cifar10', train=False, download=True, transform=trans_cifar_val)
+ dataset_train = datasets.CIFAR10('data/cifar10', train=True, download=True, transform=trans_cifar10_train)
+ dataset_test = datasets.CIFAR10('data/cifar10', train=False, download=True, transform=trans_cifar10_val)
if args.iid:
- dict_users_train = cifar10_iid(dataset_train, args.num_users)
- dict_users_test = cifar10_iid(dataset_test, args.num_users)
+ dict_users_train = iid(dataset_train, args.num_users)
+ dict_users_test = iid(dataset_test, args.num_users)
else:
- dict_users_train, rand_set_all = cifar10_noniid(dataset_train, args.num_users, num_shards=200, num_imgs=250,
- train=True, rand_set_all=rand_set_all)
- dict_users_test, rand_set_all = cifar10_noniid(dataset_train, args.num_users, num_shards=200, num_imgs=250,
- train=True, rand_set_all=rand_set_all)
+ dict_users_train, rand_set_all = noniid(dataset_train, args.num_users, args.shard_per_user)
+ dict_users_test, rand_set_all = noniid(dataset_test, args.num_users, args.shard_per_user, rand_set_all=rand_set_all)
+ elif args.dataset == 'cifar100':
+ dataset_train = datasets.CIFAR100('data/cifar100', train=True, download=True, transform=trans_cifar100_train)
+ dataset_test = datasets.CIFAR100('data/cifar100', train=False, download=True, transform=trans_cifar100_val)
+ if args.iid:
+ dict_users_train = iid(dataset_train, args.num_users)
+ dict_users_test = iid(dataset_test, args.num_users)
+ else:
+ dict_users_train, rand_set_all = noniid(dataset_train, args.num_users, args.shard_per_user)
+ dict_users_test, rand_set_all = noniid(dataset_test, args.num_users, args.shard_per_user, rand_set_all=rand_set_all)
else:
exit('Error: unrecognized dataset')
- return dataset_train, dataset_test, dict_users_train, dict_users_test, rand_set_all
+ return dataset_train, dataset_test, dict_users_train, dict_users_test
def get_model(args):
if args.model == 'cnn' and args.dataset in ['cifar10', 'cifar100']: