diff --git a/main_fed.py b/main_fed.py index b9b0b33..428da6a 100644 --- a/main_fed.py +++ b/main_fed.py @@ -2,19 +2,13 @@ # -*- coding: utf-8 -*- # Python version: 3.6 -import matplotlib -matplotlib.use('Agg') -import matplotlib.pyplot as plt import copy import numpy as np -from torchvision import datasets, transforms, models import torch -from utils.sampling import mnist_iid, mnist_noniid, cifar10_iid, cifar10_noniid from utils.options import args_parser +from utils.train_utils import get_data, get_model from models.Update import LocalUpdate -from models.Nets import MLP, CNNMnist, CNNCifar, ResnetCifar -from models.Fed import FedAvg from models.test import test_img import os @@ -25,142 +19,56 @@ args = args_parser() args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu') - if not os.path.exists('./log/{}/'.format(args.results_save)): - os.makedirs('./log/{}/'.format(args.results_save)) - if not os.path.exists('./save/{}/{}/'.format(args.results_save, args.dataset)): - os.makedirs('./save/{}/{}/'.format(args.results_save, args.dataset)) - - trans_mnist = transforms.Compose([transforms.ToTensor(), - transforms.Normalize((0.1307,), (0.3081,))]) - - if args.model == 'resnet': - trans_cifar_train = transforms.Compose([transforms.RandomCrop(32, padding=4), - transforms.RandomHorizontalFlip(), - transforms.Resize([256,256]), - transforms.CenterCrop(224), - transforms.ToTensor(), - transforms.Normalize(mean=[0.485, 0.456, 0.406], - std=[0.229, 0.224, 0.225])]) - trans_cifar_val = transforms.Compose([transforms.Resize([256,256]), - transforms.CenterCrop(224), - transforms.ToTensor(), - transforms.Normalize(mean=[0.485, 0.456, 0.406], - std=[0.229, 0.224, 0.225])]) - else: - trans_cifar_train = transforms.Compose([transforms.RandomCrop(32, padding=4), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), - transforms.Normalize(mean=[0.485, 0.456, 0.406], - std=[0.229, 0.224, 0.225])]) - trans_cifar_val = 
transforms.Compose([transforms.ToTensor(), - transforms.Normalize(mean=[0.485, 0.456, 0.406], - std=[0.229, 0.224, 0.225])]) - - # load dataset and split users - if args.dataset == 'mnist': - dataset_train = datasets.MNIST('data/mnist/', train=True, download=True, transform=trans_mnist) - dataset_test = datasets.MNIST('data/mnist/', train=False, download=True, transform=trans_mnist) - # sample users - if args.iid: - dict_users = mnist_iid(dataset_train, args.num_users) - else: - dict_users, rand_set_all = mnist_noniid(dataset_train, args.num_users, num_shards=200, num_imgs=300, - train=True) - save_path = './save/{}/randset_fed_{}_iid{}_num{}_C{}.pt'.format( - args.results_save, args.dataset, args.iid, args.num_users, args.frac) - np.save(save_path, rand_set_all) - elif args.dataset == 'cifar10': - dataset_train = datasets.CIFAR10('data/cifar10', train=True, download=True, transform=trans_cifar_train) - dataset_test = datasets.CIFAR10('data/cifar10', train=False, download=True, transform=trans_cifar_val) - if args.iid: - dict_users = cifar10_iid(dataset_train, args.num_users) - else: - dict_users, _ = cifar10_noniid(dataset_train, args.num_users) - # exit('Error: only consider IID setting in CIFAR10') - elif args.dataset == 'cifar100': - dataset_train = datasets.CIFAR100('data/cifar100', train=True, download=True, transform=trans_cifar_train) - dataset_test = datasets.CIFAR100('data/cifar100', train=False, download=True, transform=trans_cifar_val) - if args.iid: - dict_users = cifar10_iid(dataset_train, args.num_users) - else: - dict_users, _ = cifar10_noniid(dataset_train, args.num_users, num_shards=1000, num_imgs=50, train=True) - else: - exit('Error: unrecognized dataset') - img_size = dataset_train[0][0].shape + base_dir = './save/{}/{}_iid{}_num{}_C{}_le{}/{}/'.format( + args.dataset, args.model, args.iid, args.num_users, args.frac, args.local_ep, args.results_save) + if not os.path.exists(os.path.join(base_dir, 'fed')): + os.makedirs(os.path.join(base_dir, 
'fed'), exist_ok=True) + + dataset_train, dataset_test, dict_users_train, _, rand_set_all = get_data(args) + rand_save_path = os.path.join(base_dir, 'randset.npy') + np.save(rand_save_path, rand_set_all) # build model - if args.model == 'cnn' and args.dataset in ['cifar10', 'cifar100']: - net_glob = CNNCifar(args=args).to(args.device) - elif args.model == 'cnn' and args.dataset == 'mnist': - net_glob = CNNMnist(args=args).to(args.device) - elif args.model == 'resnet' and args.dataset in ['cifar10', 'cifar100']: - net_glob = ResnetCifar(args=args).to(args.device) - elif args.model == 'mlp': - len_in = 1 - for x in img_size: - len_in *= x - net_glob = MLP(dim_in=len_in, dim_hidden=256, dim_out=args.num_classes).to(args.device) - elif args.model == 'mlp_orig': - len_in = 1 - for x in img_size: - len_in *= x - net_glob = MLP(dim_in=len_in, dim_hidden=256, dim_out=args.num_classes).to(args.device) - else: - exit('Error: unrecognized model') - print(net_glob) + net_glob = get_model(args) net_glob.train() # training + results_save_path = os.path.join(base_dir, 'fed/results.npy') + loss_train = [] - cv_loss, cv_acc = [], [] - val_loss_pre, counter = 0, 0 net_best = None best_loss = None best_acc = None best_epoch = None - val_acc_list, net_list = [], [] lr = args.lr results = [] for iter in range(args.epochs): w_glob = None - loss_locals, grads_local = [], [] + loss_locals = [] m = max(int(args.frac * args.num_users), 1) idxs_users = np.random.choice(range(args.num_users), m, replace=False) print("Round {}, lr: {:.6f}, {}".format(iter, lr, idxs_users)) for idx in idxs_users: - local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx]) + local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users_train[idx]) net_local = copy.deepcopy(net_glob) w_local, loss = local.train(net=net_local.to(args.device)) loss_locals.append(copy.deepcopy(loss)) - if not args.grad_norm: - grads = 1.0 - else: - grads = [] - for grad in [param.grad for param in 
net_local.parameters()]: - if grad is not None: - grads.append(grad.view(-1)) - grads = torch.cat(grads).norm().item() - # print(grads) - grads_local.append(grads) - if w_glob is None: w_glob = copy.deepcopy(w_local) - for k in w_glob.keys(): - w_glob[k] *= grads else: for k in w_glob.keys(): - w_glob[k] += w_local[k] * grads + w_glob[k] += w_local[k] lr *= args.lr_decay # update global weights for k in w_glob.keys(): - w_glob[k] = torch.div(w_glob[k], sum(grads_local)) + w_glob[k] = torch.div(w_glob[k], m) # copy weight to net_glob net_glob.load_state_dict(w_glob) @@ -169,7 +77,7 @@ loss_avg = sum(loss_locals) / len(loss_locals) loss_train.append(loss_avg) - if (iter + 1) % 1 == 0: + if (iter + 1) % args.test_freq == 0: net_glob.eval() acc_test, loss_test = test_img(net_glob, dataset_test, args) print('Round {:3d}, Average loss {:.3f}, Test loss {:.3f}, Test accuracy: {:.2f}'.format( @@ -177,35 +85,13 @@ results.append(np.array([iter, loss_avg, loss_test, acc_test])) final_results = np.array(results) - - results_save_path = './log/{}/fed_{}_{}_iid{}_num{}_C{}_le{}_gn{}.npy'.format( - args.results_save, args.dataset, args.model, args.iid, args.num_users, args.frac, args.local_ep, args.grad_norm) np.save(results_save_path, final_results) - if args.dataset == 'mnist': - start_saving = 350 - else: - start_saving = 950 - - model_save_path = './save/{}/{}/fed_{}_{}_iid{}_num{}_C{}_le{}_gn{}_iter{}.pt'.format( - args.results_save, args.dataset, args.dataset, args.model, args.iid, args.num_users, args.frac, args.local_ep, args.grad_norm, iter) + model_save_path = os.path.join(base_dir, 'fed/model_{}.pt'.format(iter)) if best_acc is None or acc_test > best_acc: best_acc = acc_test best_epoch = iter - if iter > start_saving: + if iter > args.start_saving: torch.save(net_glob.state_dict(), model_save_path) - print('Best model, iter: {}, acc: {}'.format(best_epoch, best_acc)) - - # plot loss curve - plt.figure() - plt.plot(range(len(loss_train)), loss_train) - 
plt.ylabel('train_loss') - plt.savefig('./log/fed_{}_{}_{}_C{}_iid{}.png'.format(args.dataset, args.model, args.epochs, args.frac, args.iid)) - - # testing - net_glob.eval() - acc_train, loss_train = test_img(net_glob, dataset_train, args) - acc_test, loss_test = test_img(net_glob, dataset_test, args) - print("Training accuracy: {:.2f}".format(acc_train)) - print("Testing accuracy: {:.2f}".format(acc_test)) \ No newline at end of file + print('Best model, iter: {}, acc: {}'.format(best_epoch, best_acc)) \ No newline at end of file diff --git a/main_lg.py b/main_lg.py index 535dace..be975b7 100644 --- a/main_lg.py +++ b/main_lg.py @@ -2,23 +2,16 @@ # -*- coding: utf-8 -*- # Python version: 3.6 -import matplotlib -matplotlib.use('Agg') -import matplotlib.pyplot as plt import copy import os import itertools import numpy as np -from scipy.stats import mode -from torchvision import datasets, transforms, models import torch from torch import nn -from utils.sampling import mnist_iid, mnist_noniid, cifar10_iid, cifar10_noniid from utils.options import args_parser +from utils.train_utils import get_data, get_model from models.Update import LocalUpdate -from models.Nets import MLP, CNNMnist, CNNCifar, ResnetCifar -from models.Fed import FedAvg from models.test import test_img, test_img_local import pdb @@ -28,93 +21,27 @@ args = args_parser() args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu') - trans_mnist = transforms.Compose([transforms.ToTensor(), - transforms.Normalize((0.1307,), (0.3081,))]) - - if args.model == 'resnet': - trans_cifar_train = transforms.Compose([transforms.RandomCrop(32, padding=4), - transforms.RandomHorizontalFlip(), - transforms.Resize([256,256]), - transforms.CenterCrop(224), - transforms.ToTensor(), - transforms.Normalize(mean=[0.485, 0.456, 0.406], - std=[0.229, 0.224, 0.225])]) - trans_cifar_val = transforms.Compose([transforms.Resize([256,256]), - transforms.CenterCrop(224), - 
transforms.ToTensor(), - transforms.Normalize(mean=[0.485, 0.456, 0.406], - std=[0.229, 0.224, 0.225])]) - else: - trans_cifar_train = transforms.Compose([transforms.RandomCrop(32, padding=4), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), - transforms.Normalize(mean=[0.485, 0.456, 0.406], - std=[0.229, 0.224, 0.225])]) - trans_cifar_val = transforms.Compose([transforms.ToTensor(), - transforms.Normalize(mean=[0.485, 0.456, 0.406], - std=[0.229, 0.224, 0.225])]) - - # load dataset and split users - if args.dataset == 'mnist': - dataset_train = datasets.MNIST('data/mnist/', train=True, download=True, transform=trans_mnist) - dataset_test = datasets.MNIST('data/mnist/', train=False, download=True, transform=trans_mnist) - # sample users - if args.iid: - dict_users_train = mnist_iid(dataset_train, args.num_users) - dict_users_test = mnist_iid(dataset_test, args.num_users) - else: - dict_users_train, rand_set_all = mnist_noniid(dataset_train, args.num_users, num_shards=200, num_imgs=300, train=True) - dict_users_test, _ = mnist_noniid(dataset_test, args.num_users, num_shards=200, num_imgs=50, train=False, rand_set_all=rand_set_all) - save_path = './save/{}/randset_lg_{}_iid{}_num{}_C{}.pt'.format( - args.results_save, args.dataset, args.iid, args.num_users, args.frac) - np.save(save_path, rand_set_all) - - elif args.dataset == 'cifar10': - dataset_train = datasets.CIFAR10('data/cifar10', train=True, download=True, transform=trans_cifar_train) - dataset_test = datasets.CIFAR10('data/cifar10', train=False, download=True, transform=trans_cifar_val) - if args.iid: - dict_users_train = cifar10_iid(dataset_train, args.num_users) - dict_users_test = cifar10_iid(dataset_test, args.num_users) - else: - dict_users_train, rand_set_all = cifar10_noniid(dataset_train, args.num_users, num_shards=200, num_imgs=250, train=True) - dict_users_test, _ = cifar10_noniid(dataset_test, args.num_users, num_shards=200, num_imgs=50, train=False, rand_set_all=rand_set_all) - - 
elif args.dataset == 'cifar100': - dataset_train = datasets.CIFAR100('data/cifar100', train=True, download=True, transform=trans_cifar_train) - dataset_test = datasets.CIFAR100('data/cifar100', train=False, download=True, transform=trans_cifar_val) - if args.iid: - dict_users_train = cifar10_iid(dataset_train, args.num_users) - dict_users_test = cifar10_iid(dataset_test, args.num_users) - else: - dict_users_train, rand_set_all = cifar10_noniid(dataset_train, args.num_users, num_shards=1000, num_imgs=50, train=True) - dict_users_test, _ = cifar10_noniid(dataset_test, args.num_users, num_shards=1000, num_imgs=10, train=False, rand_set_all=rand_set_all) - else: - exit('Error: unrecognized dataset') - img_size = dataset_train[0][0].shape + base_dir = './save/{}/{}_iid{}_num{}_C{}_le{}/{}/'.format( + args.dataset, args.model, args.iid, args.num_users, args.frac, args.local_ep, args.results_save) + if not os.path.exists(os.path.join(base_dir, 'lg')): + os.makedirs(os.path.join(base_dir, 'lg'), exist_ok=True) + + rand_set_all = [] + if len(args.load_fed) > 0: + rand_save_path = os.path.join(base_dir, 'randset.npy') + rand_set_all= np.load(rand_save_path) + + dataset_train, dataset_test, dict_users_train, dict_users_test, rand_set_all = get_data(args, rand_set_all=rand_set_all) + rand_save_path = os.path.join(base_dir, 'randset.npy') + np.save(rand_save_path, rand_set_all) # build model - if args.model == 'cnn' and args.dataset in ['cifar10', 'cifar100']: - net_glob = CNNCifar(args=args).to(args.device) - elif args.model == 'cnn' and args.dataset == 'mnist': - net_glob = CNNMnist(args=args).to(args.device) - elif args.model == 'resnet' and args.dataset in ['cifar10', 'cifar100']: - net_glob = ResnetCifar(args=args).to(args.device) - elif args.model == 'mlp': - len_in = 1 - for x in img_size: - len_in *= x - net_glob = MLP(dim_in=len_in, dim_hidden=64, dim_out=args.num_classes).to(args.device) - else: - exit('Error: unrecognized model') - - print(net_glob) + net_glob = 
get_model(args) net_glob.train() - if args.load_fed: - fed_model_path = './save/keep/fed_{}_{}_iid{}_num{}_C{}_le{}_gn{}.npy'.format( - args.dataset, args.model, args.iid, args.num_users, args.frac, args.local_ep, args.grad_norm) - if len(args.load_fed_name) > 0: - fed_model_path = './save/keep/{}'.format(args.load_fed_name) - net_glob.load_state_dict(torch.load(fed_model_path)) + + if len(args.load_fed) > 0: + fed_model_path = os.path.join(base_dir, 'fed/{}'.format(args.load_fed)) + net_glob.load_state_dict(torch.load(fed_model_path)) total_num_layers = len(net_glob.weight_keys) w_glob_keys = net_glob.weight_keys[total_num_layers - args.num_layers_keep:] @@ -197,50 +124,21 @@ def test_img_avg_all(): return acc_test_avg, loss_test_avg - if args.local_ep_pretrain > 0: - # pretrain each local model - pretrain_save_path = 'pretrain/{}/{}_{}/user_{}/ep_{}/'.format(args.model, args.dataset, - 'iid' if args.iid else 'noniid', args.num_users, - args.local_ep_pretrain) - if not os.path.exists(pretrain_save_path): - os.makedirs(pretrain_save_path) - - print("\nPretraining local models...") - for idx in range(args.num_users): - net_local = net_local_list[idx] - net_local_path = os.path.join(pretrain_save_path, '{}.pt'.format(idx)) - if os.path.exists(net_local_path): # check if we have a saved model - net_local.load_state_dict(torch.load(net_local_path)) - else: - local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users_train[idx], pretrain=True) - w_local, loss = local.train(net=net_local.to(args.device)) - print('Local model {}, Train Epoch Loss: {:.4f}'.format(idx, loss)) - torch.save(net_local.state_dict(), net_local_path) - - print("Getting initial loss and acc...") - acc_test_local, loss_test_local = test_img_local_all() - acc_test_avg, loss_test_avg = test_img_avg_all() - loss_test, acc_test = test_img_ensemble_all() - - print('Initial Ensemble: Loss (local): {:.3f}, Acc (local): {:.2f}, Loss (avg): {:.3}, Acc (avg): {:.2f}, Loss (ens) {:.3f}, Acc: (ens) 
{:.2f}, '.format( - loss_test_local, acc_test_local, loss_test_avg, acc_test_avg, loss_test, acc_test)) - # training + results_save_path = os.path.join(base_dir, 'lg/results.npy') + loss_train = [] - cv_loss, cv_acc = [], [] - val_loss_pre, counter = 0, 0 net_best = None best_loss = None best_acc = None best_epoch = None - val_acc_list, net_list = [], [] lr = args.lr results = [] for iter in range(args.epochs): w_glob = {} - loss_locals, grads_local = [], [] + loss_locals = [] m = max(int(args.frac * args.num_users), 1) idxs_users = np.random.choice(range(args.num_users), m, replace=False) # w_keys_epoch = net_glob.state_dict().keys() if (iter + 1) % 25 == 0 else w_glob_keys @@ -258,42 +156,19 @@ def test_img_avg_all(): modules_glob = set([x.split('.')[0] for x in w_keys_epoch]) modules_all = net_local.__dict__['_modules'] - # use grads to calculate a weighted average - if not args.grad_norm: - grads = 1.0 - else: - grads = [] - for key in modules_glob: - module = modules_all[key] - grad = module.weight.grad - if grad is not None: - grads.append(grad.view(-1)) - - try: - grad = module.bias.grad - if grad is not None: - grads.append(grad.view(-1)) - except: - pass - grads = torch.cat(grads).norm().item() - # print(grads) - grads_local.append(grads) - # sum up weights if len(w_glob) == 0: w_glob = copy.deepcopy(net_glob.state_dict()) - for k in w_keys_epoch: # this depends on the layers being named the same (in Nets.py) - w_glob[k] = w_local[k] * grads else: for k in w_keys_epoch: - w_glob[k] += w_local[k] * grads + w_glob[k] += w_local[k] if (iter+1) % int(args.num_users * args.frac): lr *= args.lr_decay # get weighted average for global weights for k in w_keys_epoch: - w_glob[k] = torch.div(w_glob[k], sum(grads_local)) + w_glob[k] = torch.div(w_glob[k], m) # copy weight to the global model (not really necessary) net_glob.load_state_dict(w_glob) @@ -314,22 +189,11 @@ def test_img_avg_all(): acc_test_local, loss_test_local = test_img_local_all() acc_test_avg, 
loss_test_avg = test_img_avg_all() - # if (iter + 1) % args.test_freq == 0: # this takes too much time, so we run it less frequently - if (iter + 1) > 75: - loss_test, acc_test = test_img_ensemble_all() - print('Round {:3d}, Avg Loss {:.3f}, Loss (local): {:.3f}, Acc (local): {:.2f}, Loss (avg): {:.3}, Acc (avg): {:.2f}, Loss (ens) {:.3f}, Acc: (ens) {:.2f}, '.format( - iter, loss_avg, loss_test_local, acc_test_local, loss_test_avg, acc_test_avg, loss_test, acc_test)) - results.append(np.array([iter, loss_avg, loss_test_local, acc_test_local, loss_test_avg, acc_test_avg, loss_test, acc_test])) - - else: - print('Round {:3d}, Avg Loss {:.3f}, Loss (local): {:.3f}, Acc (local): {:.2f}, Loss (avg): {:.3}, Acc (avg): {:.2f}'.format( - iter, loss_avg, loss_test_local, acc_test_local, loss_test_avg, acc_test_avg)) - results.append(np.array([iter, loss_avg, loss_test_local, acc_test_local, loss_test_avg, acc_test_avg, np.nan, np.nan])) + print('Round {:3d}, Avg Loss {:.3f}, Loss (local): {:.3f}, Acc (local): {:.2f}, Loss (avg): {:.3}, Acc (avg): {:.2f}'.format( + iter, loss_avg, loss_test_local, acc_test_local, loss_test_avg, acc_test_avg)) + results.append(np.array([iter, loss_avg, loss_test_local, acc_test_local, loss_test_avg, acc_test_avg])) final_results = np.array(results) - results_save_path = './log/lg_{}_{}_keep{}_iid{}_num{}_C{}_le{}_gn{}_pt{}_load{}_tfreq{}.npy'.format( - args.dataset, args.model, args.num_layers_keep, args.iid, args.num_users, args.frac, - args.local_ep, args.grad_norm, args.local_ep_pretrain, args.load_fed, args.test_freq) np.save(results_save_path, final_results) if best_acc is None or acc_test_local > best_acc: @@ -337,35 +201,14 @@ def test_img_avg_all(): best_epoch = iter for user in range(args.num_users): - model_save_path = './save/{}/{}_{}_iid{}/'.format( - args.results_save, args.dataset, 0, args.iid) - if not os.path.exists(model_save_path): - os.makedirs(model_save_path) - - model_save_path = 
'./save/{}/{}_{}_iid{}/user{}.pt'.format( - args.results_save, args.dataset, 0, args.iid, user) + model_save_path = os.path.join(base_dir, 'lg/model_user{}.pt'.format(user)) torch.save(net_local_list[user].state_dict(), model_save_path) for user in range(args.num_users): - model_save_path = './save/{}/{}_{}_iid{}/user{}.pt'.format( - args.results_save, args.dataset, 0, args.iid, user) + model_save_path = os.path.join(base_dir, 'lg/model_user{}.pt'.format(user)) net_local = net_local_list[idx] net_local.load_state_dict(torch.load(model_save_path)) loss_test, acc_test = test_img_ensemble_all() - print('Best model, iter: {}, acc: {}, acc (ens): {}'.format(best_epoch, best_acc, acc_test)) - - - # plot loss curve - plt.figure() - plt.plot(range(len(loss_train)), loss_train) - plt.ylabel('train_loss') - plt.savefig('./log/fed_{}_{}_{}_C{}_iid{}.png'.format(args.dataset, args.model, args.epochs, args.frac, args.iid)) - - # testing - net_glob.eval() - acc_train, loss_train = test_img(net_glob, dataset_train, args) - acc_test, loss_test = test_img(net_glob, dataset_test, args) - print("Training accuracy: {:.2f}".format(acc_train)) - print("Testing accuracy: {:.2f}".format(acc_test)) \ No newline at end of file + print('Best model, iter: {}, acc: {}, acc (ens): {}'.format(best_epoch, best_acc, acc_test)) \ No newline at end of file diff --git a/models/Nets.py b/models/Nets.py index 7b6b2b5..602bb38 100755 --- a/models/Nets.py +++ b/models/Nets.py @@ -101,27 +101,4 @@ def forward(self, x): x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) x = self.fc3(x) - return F.log_softmax(x, dim=1) - - -class ResnetCifar(nn.Module): - def __init__(self, args): - super(ResnetCifar, self).__init__() - self.extractor = models.resnet18(pretrained=False) - self.fflayer = nn.Sequential(nn.Linear(1000, args.num_classes)) - - def forward(self, x): - x = self.extractor(x) - x = self.fflayer(x) - return F.log_softmax(x, dim=1) - -class ResnetCifar(nn.Module): - def __init__(self, args): - 
super(ResnetCifar, self).__init__() - self.extractor = models.resnet18(pretrained=False) - self.fflayer = nn.Sequential(nn.Linear(1000, args.num_classes)) - - def forward(self, x): - x = self.extractor(x) - x = self.fflayer(x) return F.log_softmax(x, dim=1) \ No newline at end of file diff --git a/utils/options.py b/utils/options.py index 9d83ffc..292cf7e 100644 --- a/utils/options.py +++ b/utils/options.py @@ -42,9 +42,9 @@ def args_parser(): parser.add_argument('--print_freq', type=int, default=100, help="print loss frequency during training") parser.add_argument('--seed', type=int, default=1, help='random seed (default: 1)') parser.add_argument('--test_freq', type=int, default=1, help='how often to test on val set') - parser.add_argument('--load_fed', action='store_true', help='load pretrained federated model for local_global') - parser.add_argument('--load_fed_name', type=str, default='', help='define pretrained federated model path') + parser.add_argument('--load_fed', type=str, default='', help='define pretrained federated model path') parser.add_argument('--results_save', type=str, default='/', help='define fed results save folder') + parser.add_argument('--start_saving', type=int, default=0, help='when to start saving models') args = parser.parse_args() return args diff --git a/utils/train_utils.py b/utils/train_utils.py new file mode 100644 index 0000000..2c227d3 --- /dev/null +++ b/utils/train_utils.py @@ -0,0 +1,56 @@ +from torchvision import datasets, transforms +from models.Nets import MLP, CNNMnist, CNNCifar +from utils.sampling import mnist_iid, mnist_noniid, cifar10_iid, cifar10_noniid + +def get_data(args, rand_set_all=[]): + trans_mnist = transforms.Compose([transforms.ToTensor(), + transforms.Normalize((0.1307,), (0.3081,))]) + trans_cifar_train = transforms.Compose([transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225])]) 
+ trans_cifar_val = transforms.Compose([transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225])]) + + if args.dataset == 'mnist': + dataset_train = datasets.MNIST('data/mnist/', train=True, download=True, transform=trans_mnist) + dataset_test = datasets.MNIST('data/mnist/', train=False, download=True, transform=trans_mnist) + # sample users + if args.iid: + dict_users_train = mnist_iid(dataset_train, args.num_users) + dict_users_test = mnist_iid(dataset_test, args.num_users) + else: + dict_users_train, rand_set_all = mnist_noniid(dataset_train, args.num_users, num_shards=200, num_imgs=300, + train=True, rand_set_all=rand_set_all) + dict_users_test, rand_set_all = mnist_noniid(dataset_test, args.num_users, num_shards=200, num_imgs=50, + train=False, rand_set_all=rand_set_all) + elif args.dataset == 'cifar10': + dataset_train = datasets.CIFAR10('data/cifar10', train=True, download=True, transform=trans_cifar_train) + dataset_test = datasets.CIFAR10('data/cifar10', train=False, download=True, transform=trans_cifar_val) + if args.iid: + dict_users_train = cifar10_iid(dataset_train, args.num_users) + dict_users_test = cifar10_iid(dataset_test, args.num_users) + else: + dict_users_train, rand_set_all = cifar10_noniid(dataset_train, args.num_users, num_shards=200, num_imgs=250, + train=True, rand_set_all=rand_set_all) + dict_users_test, rand_set_all = cifar10_noniid(dataset_test, args.num_users, num_shards=200, num_imgs=50, + train=False, rand_set_all=rand_set_all) + else: + exit('Error: unrecognized dataset') + + return dataset_train, dataset_test, dict_users_train, dict_users_test, rand_set_all + +def get_model(args): + if args.model == 'cnn' and args.dataset in ['cifar10', 'cifar100']: + net_glob = CNNCifar(args=args).to(args.device) + elif args.model == 'cnn' and args.dataset == 'mnist': + net_glob = CNNMnist(args=args).to(args.device) + elif args.model == 'mlp' and args.dataset == 'mnist': + net_glob = MLP(dim_in=784, 
dim_hidden=256, dim_out=args.num_classes).to(args.device) + else: + exit('Error: unrecognized model') + print(net_glob) + + return net_glob \ No newline at end of file