From b6db67e21daff96a26e5575241599d5e1fa00b58 Mon Sep 17 00:00:00 2001
From: tayebiarasteh
Date: Fri, 1 Apr 2022 17:18:55 +0200
Subject: [PATCH] Add federated training

---
 Train_Valid_brats.py        | 284 +++++++++++++++++++++++++++---------
 config/config.yaml          |  12 +-
 data/data_provider_brats.py |   5 +-
 main_3D_brats.py            |  82 ++++++++++-
 models/EDiceLoss_loss.py    |   8 ++--
 5 files changed, 308 insertions(+), 83 deletions(-)

diff --git a/Train_Valid_brats.py b/Train_Valid_brats.py
index c9dda13..7c7c0e4 100644
--- a/Train_Valid_brats.py
+++ b/Train_Valid_brats.py
@@ -14,6 +14,7 @@
 from tqdm import tqdm
 import torchmetrics
 import torch.nn.functional as F
+import syft as sy
 
 from config.serde import read_config, write_config
 from data.augmentation_brats import random_augment
@@ -21,6 +22,7 @@
 import warnings
 warnings.filterwarnings('ignore')
 epsilon = 1e-15
+hook = sy.TorchHook(torch)
 

@@ -47,7 +49,6 @@ def __init__(self, cfg_path, num_epochs=10, resume=False, augment=False):
         if resume == False:
             self.model_info = self.params['Network']
             self.epoch = 0
-            self.step = 0
             self.best_loss = float('inf')
             self.setup_cuda()
             self.writer = SummaryWriter(log_dir=os.path.join(self.params['target_dir'], self.params['tb_logs_path']))

@@ -177,10 +178,9 @@ def load_checkpoint(self, model, optimiser, loss_function, weight=None):
         self.model.load_state_dict(checkpoint['model_state_dict'])
         self.optimiser.load_state_dict(checkpoint['optimizer_state_dict'])
         self.epoch = checkpoint['epoch']
-        self.step = checkpoint['step']
         self.best_loss = checkpoint['best_loss']
 
         self.writer = SummaryWriter(log_dir=os.path.join(os.path.join(
-            self.params['target_dir'], self.params['tb_logs_path'])), purge_step=self.step + 1)
+            self.params['target_dir'], self.params['tb_logs_path'])), purge_step=self.epoch + 1)

@@ -188,16 +188,14 @@ def train_epoch(self, train_loader, valid_loader=None):
         """Training epoch
         """
         self.params = read_config(self.cfg_path)
+        total_start_time = time.time()
 
         for epoch in range(self.num_epochs - self.epoch):
             self.epoch += 1
 
             # initializing the loss list
             batch_loss = 0
-            batch_count = 0
-
             start_time = time.time()
-            total_start_time = time.time()
 
             for idx, (image, label) in enumerate(train_loader):
                 self.model.train()

@@ -224,43 +222,199 @@
                 self.optimiser.step()
 
                 batch_loss += loss.item()
-                batch_count += 1
-                self.step += 1
-
-                # Prints train loss after number of steps specified.
-                if (self.step) % self.params['display_train_loss_freq'] == 0:
-                    end_time = time.time()
-                    iteration_hours, iteration_mins, iteration_secs = self.time_duration(start_time, end_time)
-                    total_hours, total_mins, total_secs = self.time_duration(total_start_time, end_time)
-                    train_loss = batch_loss / batch_count
-                    batch_loss = 0
-                    batch_count = 0
-                    start_time = time.time()
-
-                    print('Step {} | train epoch {} | batch {} / {} | loss: {:.3f}'.
-                          format(self.step, self.epoch, idx+1, len(train_loader), train_loss),
-                          f'\ntime: {iteration_hours}h {iteration_mins}m {iteration_secs}s',
-                          f'| total: {total_hours}h {total_mins}m {total_secs}s\n')
-                    self.writer.add_scalar('Train_loss', train_loss, self.step)
-
-                # Validation iteration & calculate metrics
-                if (self.step) % (self.params['display_stats_freq']) == 0:
-
-                    # saving the model, checkpoint, TensorBoard, etc.
-                    if not valid_loader == None:
-                        valid_loss, valid_F1, valid_accuracy, valid_specifity, valid_sensitivity, valid_precision = self.valid_epoch(valid_loader)
-                        end_time = time.time()
-                        total_hours, total_mins, total_secs = self.time_duration(total_start_time, end_time)
-
-                        self.calculate_tb_stats(valid_loss=valid_loss, valid_F1=valid_F1, valid_accuracy=valid_accuracy, valid_specifity=valid_specifity,
-                                                valid_sensitivity=valid_sensitivity, valid_precision=valid_precision)
-                        self.savings_prints(iteration_hours, iteration_mins, iteration_secs, total_hours,
-                                            total_mins, total_secs, train_loss, valid_loss=valid_loss,
-                                            valid_F1=valid_F1, valid_accuracy=valid_accuracy, valid_specifity= valid_specifity,
+
+            # Log the average train loss of this epoch.
+            train_loss = batch_loss / len(train_loader)
+            self.writer.add_scalar('Train_loss', train_loss, self.epoch)
+
+            # Validation iteration & calculate metrics
+            if (self.epoch) % (self.params['display_stats_freq']) == 0:
+
+                # saving the model, checkpoint, TensorBoard, etc.
+                if valid_loader is not None:
+                    valid_loss, valid_F1, valid_accuracy, valid_specifity, valid_sensitivity, valid_precision = self.valid_epoch(valid_loader)
+                    end_time = time.time()
+                    iteration_hours, iteration_mins, iteration_secs = self.time_duration(start_time, end_time)
+                    total_hours, total_mins, total_secs = self.time_duration(total_start_time, end_time)
+
+                    self.calculate_tb_stats(valid_loss=valid_loss, valid_F1=valid_F1, valid_accuracy=valid_accuracy, valid_specifity=valid_specifity,
                                             valid_sensitivity=valid_sensitivity, valid_precision=valid_precision)
-                    else:
-                        self.savings_prints(iteration_hours, iteration_mins, iteration_secs, total_hours,
-                                            total_mins, total_secs, train_loss)
+                    self.savings_prints(iteration_hours, iteration_mins, iteration_secs, total_hours,
+                                        total_mins, total_secs, train_loss, valid_loss=valid_loss,
+                                        valid_F1=valid_F1, valid_accuracy=valid_accuracy, valid_specifity=valid_specifity,
+                                        valid_sensitivity=valid_sensitivity, valid_precision=valid_precision)
+                else:
+                    end_time = time.time()
+                    iteration_hours, iteration_mins, iteration_secs = self.time_duration(start_time, end_time)
+                    total_hours, total_mins, total_secs = self.time_duration(total_start_time, end_time)
+                    self.savings_prints(iteration_hours, iteration_mins, iteration_secs, total_hours,
+                                        total_mins, total_secs, train_loss)
+
+
+    def training_setup_federated(self, train_loader, valid_loader=None):
+        """Federated training with per-epoch federated averaging (FedAvg).
+        """
+        self.params = read_config(self.cfg_path)
+
+        print('preparing the workers and their data...\n')
+        # create the virtual workers
+        client1 = sy.VirtualWorker(hook, id="client1")
+        client2 = sy.VirtualWorker(hook, id="client2")
+        client3 = sy.VirtualWorker(hook, id="client3")
+        secure_worker = sy.VirtualWorker(hook, id="secure_worker")
+
+        client1.clear_objects()
+        client2.clear_objects()
+        client3.clear_objects()
+        secure_worker.clear_objects()
+
+        client1.add_workers([client2, client3, secure_worker])
+        client2.add_workers([client1, client3, secure_worker])
+        client3.add_workers([client1, client2, secure_worker])
+        secure_worker.add_workers([client1, client2, client3])
+
+        # train_loader_client1 = []
+        # for batch_idx, (data, target) in enumerate(tqdm(train_loader[0])):
+        #     data = data.send(client1)
+        #     target = target.send(client1)
+        #     train_loader_client1.append((data, target))
+        # print('\nclient 1 done!')
+        #
+        # train_loader_client2 = []
+        # for batch_idx, (data, target) in enumerate(tqdm(train_loader[1])):
+        #     data = data.send(client2)
+        #     target = target.send(client2)
+        #     train_loader_client2.append((data, target))
+        # print('\nclient 2 done!')
+        #
+        # train_loader_client3 = []
+        # for batch_idx, (data, target) in enumerate(train_loader[2]):
+        #     data = data.send(client3)
+        #     target = target.send(client3)
+        #     train_loader_client3.append((data, target))
+
+        print('\nall done!')
+
+        total_start_time = time.time()
+        for epoch in range(self.num_epochs - self.epoch):
+            self.epoch += 1
+
+            start_time = time.time()
+            self.model.train()
+
+            model_client1 = self.model.copy().send(client1)
+            model_client2 = self.model.copy().send(client2)
+            model_client3 = self.model.copy().send(client3)
+
+            optimizer_client1 = torch.optim.Adam(model_client1.parameters(), lr=float(self.params['Network']['lr']),
+                                                 weight_decay=float(self.params['Network']['weight_decay']),
+                                                 amsgrad=self.params['Network']['amsgrad'])
+            optimizer_client2 = torch.optim.Adam(model_client2.parameters(), lr=float(self.params['Network']['lr']),
+                                                 weight_decay=float(self.params['Network']['weight_decay']),
+                                                 amsgrad=self.params['Network']['amsgrad'])
+            optimizer_client3 = torch.optim.Adam(model_client3.parameters(), lr=float(self.params['Network']['lr']),
+                                                 weight_decay=float(self.params['Network']['weight_decay']),
+                                                 amsgrad=self.params['Network']['amsgrad'])
+
+            model_client1, loss_client1 = self.train_epoch_federated(train_loader[0], optimizer_client1, model_client1)
+            model_client2, loss_client2 = self.train_epoch_federated(train_loader[1], optimizer_client2, model_client2)
+            model_client3, loss_client3 = self.train_epoch_federated(train_loader[2], optimizer_client3, model_client3)
+
+            model_client1.move(secure_worker)
+            model_client2.move(secure_worker)
+            model_client3.move(secure_worker)
+
+            print('global weight sum before averaging:', self.model.output_block.conv_out.weight.data.sum().item())
+
+            with torch.no_grad():
+                for param, param1, param2, param3 in zip(self.model.parameters(), model_client1.parameters(),
+                                                         model_client2.parameters(), model_client3.parameters()):
+                    param.data = ((param1.data + param2.data + param3.data) / 3).get()
+
+            print('global weight sum after averaging:', self.model.output_block.conv_out.weight.data.sum().item())
+
+            # overall train loss as the average of the three client losses
+            train_loss = (loss_client1 + loss_client2 + loss_client3) / 3
+
+            # Print and log the per-client train losses of this epoch.
+            end_time = time.time()
+            iteration_hours, iteration_mins, iteration_secs = self.time_duration(start_time, end_time)
+            total_hours, total_mins, total_secs = self.time_duration(total_start_time, end_time)
+
+            print('train epoch {} | loss client1: {:.3f} | loss client2: {:.3f} | loss client3: {:.3f}'.
+                  format(self.epoch, loss_client1, loss_client2, loss_client3),
+                  f'\ntime: {iteration_hours}h {iteration_mins}m {iteration_secs}s',
+                  f'| total: {total_hours}h {total_mins}m {total_secs}s\n')
+            self.writer.add_scalar('Train_loss_client1', loss_client1, self.epoch)
+            self.writer.add_scalar('Train_loss_client2', loss_client2, self.epoch)
+            self.writer.add_scalar('Train_loss_client3', loss_client3, self.epoch)
+
+            # Validation iteration & calculate metrics
+            if (self.epoch) % (self.params['display_stats_freq']) == 0:
+
+                # saving the model, checkpoint, TensorBoard, etc.
+                if valid_loader is not None:
+                    valid_loss, valid_F1, valid_accuracy, valid_specifity, valid_sensitivity, valid_precision = self.valid_epoch(valid_loader)
+                    end_time = time.time()
+                    iteration_hours, iteration_mins, iteration_secs = self.time_duration(start_time, end_time)
+                    total_hours, total_mins, total_secs = self.time_duration(total_start_time, end_time)
+
+                    self.calculate_tb_stats(valid_loss=valid_loss, valid_F1=valid_F1, valid_accuracy=valid_accuracy,
+                                            valid_specifity=valid_specifity, valid_sensitivity=valid_sensitivity, valid_precision=valid_precision)
+                    self.savings_prints(iteration_hours, iteration_mins, iteration_secs, total_hours, total_mins,
+                                        total_secs, train_loss, valid_loss=valid_loss, valid_F1=valid_F1,
+                                        valid_accuracy=valid_accuracy, valid_specifity=valid_specifity,
+                                        valid_sensitivity=valid_sensitivity, valid_precision=valid_precision)
+                else:
+                    end_time = time.time()
+                    iteration_hours, iteration_mins, iteration_secs = self.time_duration(start_time, end_time)
+                    total_hours, total_mins, total_secs = self.time_duration(total_start_time, end_time)
+                    self.savings_prints(iteration_hours, iteration_mins, iteration_secs, total_hours,
+                                        total_mins, total_secs, train_loss)
+
+
+    def train_epoch_federated(self, train_loader, optimizer, model):
+        """One local training epoch on a single client's remote data.
+        """
+
+        batch_loss = 0
+        model.train()
+
+        # training epoch of a client
+        for idx, (image, label) in enumerate(train_loader):
+
+            # if we would like to have data augmentation during training
+            if self.augment:
+                image, label = random_augment(image, label, self.cfg_path)
+
+            loc = model.location
+            image = image.send(loc)
+            label = label.send(loc)
+
+            label = label.long()
+            image = image.float()
+            image = image.to(self.device)
+            label = label.to(self.device)
+
+            optimizer.zero_grad()
+
+            with torch.set_grad_enabled(True):
+
+                output = model(image)
+                loss_client = self.loss_function(output, label)  # for dice loss
+
+                loss_client.backward()
+                optimizer.step()
+
+                batch_loss += loss_client
+
+        batch_loss = batch_loss.get().data
+        avg_loss = batch_loss / len(train_loader)
+
+        return model, avg_loss.item()

@@ -377,7 +531,6 @@ def savings_prints(self, iteration_hours, iteration_mins, iteration_secs, total_
 
         # Saves information about training to config file
         self.params['Network']['num_epoch'] = self.epoch
-        self.params['Network']['step'] = self.step
         write_config(self.params, self.cfg_path, sort_keys=True)
 
         # Saving the model based on the best loss
@@ -392,24 +545,23 @@ def savings_prints(self, iteration_hours, iteration_mins, iteration_secs, total_
             torch.save(self.model.state_dict(), os.path.join(self.params['target_dir'],
                                                              self.params['network_output_path'], self.params['trained_model_name']))
 
-        # Saving every couple of iterations
-        if (self.step) % self.params['network_save_freq'] == 0:
+        # Saving every couple of epochs
+        if (self.epoch) % self.params['network_save_freq'] == 0:
             torch.save(self.model.state_dict(), os.path.join(self.params['target_dir'], self.params['network_output_path'],
-                       'step{}_'.format(self.step) + self.params['trained_model_name']))
+                       'epoch{}_'.format(self.epoch) + self.params['trained_model_name']))
 
-        # Save a checkpoint every iteration
-        if (self.step) % self.params['network_checkpoint_freq'] == 0:
-            torch.save({'epoch': self.epoch, 'step': self.step,
-                        'model_state_dict': self.model.state_dict(),
-                        'optimizer_state_dict': self.optimiser.state_dict(),
-                        'loss_state_dict': self.loss_function.state_dict(), 'num_epochs': self.num_epochs,
-                        'model_info': self.model_info, 'best_loss': self.best_loss},
-                       os.path.join(self.params['target_dir'], self.params['network_output_path'], self.params['checkpoint_name']))
+        # Save a checkpoint every epoch
+        torch.save({'epoch': self.epoch,
+                    'model_state_dict': self.model.state_dict(),
+                    'optimizer_state_dict': self.optimiser.state_dict(),
+                    'loss_state_dict': self.loss_function.state_dict(), 'num_epochs': self.num_epochs,
+                    'model_info': self.model_info, 'best_loss': self.best_loss},
+                   os.path.join(self.params['target_dir'], self.params['network_output_path'], self.params['checkpoint_name']))
 
         print('------------------------------------------------------'
               '----------------------------------')
-        print(f'Step: {self.step} (epoch: {self.epoch}) | '
-              f'iteration time: {iteration_hours}h {iteration_mins}m {iteration_secs}s | '
+        print(f'Epoch: {self.epoch} | '
+              f'epoch time: {iteration_hours}h {iteration_mins}m {iteration_secs}s | '
               f'total time: {total_hours}h {total_mins}m {total_secs}s')
 
         print(f'\n\tTrain loss: {train_loss:.4f}')

@@ -425,7 +577,7 @@ def savings_prints(self, iteration_hours, iteration_mins, iteration_secs, total_
 
         # saving the training and validation stats
         msg = f'----------------------------------------------------------------------------------------\n' \
-              f'Step: {self.step} (epoch: {self.epoch}) | Iteration Time: {iteration_hours}h {iteration_mins}m {iteration_secs}s' \
+              f'Epoch: {self.epoch} | epoch time: {iteration_hours}h {iteration_mins}m {iteration_secs}s' \
               f' | total time: {total_hours}h {total_mins}m {total_secs}s\n\n\tTrain loss: {train_loss:.4f} | ' \
               f'Val. loss: {valid_loss:.4f} | F1 (Dice score): {valid_F1.mean().item() * 100:.2f}% | accuracy: {valid_accuracy.mean().item() * 100:.2f}% ' \
               f' | specifity: {valid_specifity.mean().item() * 100:.2f}%' \
               f' | sensitivity: {valid_sensitivity.mean().item() * 100:.2f}% | precision: {valid_precision.mean().item() * 100:.2f}%\n\n' \
               f' | F1 class 1: {valid_F1[0].item() * 100:.2f}% | F1 class 2: {valid_F1[1].item() * 100:.2f}% | F1 class 3: {valid_F1[2].item() * 100:.2f}%\n\n'
         else:
             msg = f'----------------------------------------------------------------------------------------\n' \
-                  f'Step: {self.step} (epoch: {self.epoch}) | iteration time: {iteration_hours}h {iteration_mins}m {iteration_secs}s' \
+                  f'Epoch: {self.epoch} | epoch time: {iteration_hours}h {iteration_mins}m {iteration_secs}s' \
                   f' | total time: {total_hours}h {total_mins}m {total_secs}s\n\n\ttrain loss: {train_loss:.4f}\n\n'
         with open(os.path.join(self.params['target_dir'], self.params['stat_log_path']) + '/Stats', 'a') as f:
             f.write(msg)

@@ -458,11 +610,11 @@ def calculate_tb_stats(self, valid_loss=None, valid_F1=None, valid_accuracy=None
             validation loss of the model
         """
         if valid_loss is not None:
-            self.writer.add_scalar('Valid_loss', valid_loss, self.step)
-            self.writer.add_scalar('Valid_specifity', valid_specifity.mean().item(), self.step)
-            self.writer.add_scalar('Valid_F1', valid_F1.mean().item(), self.step)
-            self.writer.add_scalar('Valid_F1 class 1', valid_F1[0].item(), self.step)
-            self.writer.add_scalar('Valid_F1 class 2', valid_F1[1].item(), self.step)
-            self.writer.add_scalar('Valid_F1 class 3', valid_F1[2].item(), self.step)
-            self.writer.add_scalar('Valid_precision', valid_precision.mean().item(), self.step)
-            self.writer.add_scalar('Valid_recall_sensitivity', valid_sensitivity.mean().item(), self.step)
\ No newline at end of file
+            self.writer.add_scalar('Valid_loss', valid_loss, self.epoch)
+            self.writer.add_scalar('Valid_specifity', valid_specifity.mean().item(), self.epoch)
+            self.writer.add_scalar('Valid_F1', valid_F1.mean().item(), self.epoch)
+            self.writer.add_scalar('Valid_F1 class 1', valid_F1[0].item(), self.epoch)
+            self.writer.add_scalar('Valid_F1 class 2', valid_F1[1].item(), self.epoch)
+            self.writer.add_scalar('Valid_F1 class 3', valid_F1[2].item(), self.epoch)
+            self.writer.add_scalar('Valid_precision', valid_precision.mean().item(), self.epoch)
+            self.writer.add_scalar('Valid_recall_sensitivity', valid_sensitivity.mean().item(), self.epoch)
\ No newline at end of file
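Review note (Train_Valid_brats.py): the epoch loop of training_setup_federated() above is one round of plain federated averaging (FedAvg): copy the global model to each client, train the copies locally, gather them on secure_worker, and average the parameters back into the global model. A minimal, self-contained sketch of that round, assuming only the PySyft 0.2.x API this patch already uses (TorchHook, VirtualWorker, .copy()/.send()/.move()/.get()); the Linear layer is a toy stand-in for the 3D U-Net:

    import torch
    import syft as sy

    hook = sy.TorchHook(torch)
    client1 = sy.VirtualWorker(hook, id="client1")
    client2 = sy.VirtualWorker(hook, id="client2")
    client3 = sy.VirtualWorker(hook, id="client3")
    secure_worker = sy.VirtualWorker(hook, id="secure_worker")

    global_model = torch.nn.Linear(4, 2)  # toy stand-in for UNet3D

    # one communication round: each client gets its own copy of the global model
    model1 = global_model.copy().send(client1)
    model2 = global_model.copy().send(client2)
    model3 = global_model.copy().send(client3)

    # ... each client would train its copy on its own loader here ...

    # gather the trained copies on the aggregator, then average them
    model1.move(secure_worker)
    model2.move(secure_worker)
    model3.move(secure_worker)
    with torch.no_grad():
        for p, p1, p2, p3 in zip(global_model.parameters(), model1.parameters(),
                                 model2.parameters(), model3.parameters()):
            p.data = ((p1.data + p2.data + p3.data) / 3).get()  # pull back only the average

Because the arithmetic happens on secure_worker and .get() is called on the averaged tensor, only the aggregate ever reaches the local machine; the individual client updates stay remote.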
diff --git a/config/config.yaml b/config/config.yaml
index a50b74e..bf51f1f 100644
--- a/config/config.yaml
+++ b/config/config.yaml
@@ -13,7 +13,7 @@ augmentation:
   general_intensity_probability: 0.5 # probability of having intensity data augmentation at all
   flip_prob: 0.2 # probability of having flipping augmentation
   zoom_range: [1, 1.15] # 0.1 means from 0.9 to 1.1
-  zoom_prob: 0.2 # probability of having zooming augmentation [for the moment only zoom in: more than 1]
+  zoom_prob: -1 # probability of having zooming augmentation [for the moment only zoom in: more than 1]
   rotation_range: 5 # degrees. 1 = (-1, 1)
   rotation_prob: -1 # probability of having rotation augmentation [DON'T USE FOR BRATS]
   shift_range: 1 # pixels. 1 = (-1, 1)
@@ -29,7 +29,7 @@ augmentation:
   motion_prob: -1 # probability of having random motion augmentation [DON'T USE FOR BRATS]
   ghosting_prob: -1 # probability of having ghosting augmentation
   blurring_range: [0, 1] # range of std of Gaussian filter; don't go more than 1
-  blurring_prob: 0.2 # probability of having blurring augmentation
+  blurring_prob: -1 # probability of having blurring augmentation
 network_output_path: network_data/
 output_data_path: output_data/
 tb_logs_path: tensor_board_logs/
@@ -41,8 +41,6 @@ trained_model_name: trained_model.pth
 # changeable items:
 file_path: /home/soroosh/Documents/datasets/BraTS20/old_BraTS20/cropped/
 target_dir: /home/soroosh/Documents/Repositories_target_files/federated_he/
-network_save_freq: 600 # based on the batch size, shows the number of batches done; every half an hour is good
-display_stats_freq: 200 # valid freq is equal to this
-display_train_loss_freq: 20
-network_checkpoint_freq: 1
-num_epochs: 250
\ No newline at end of file
+network_save_freq: 20 # number of epochs between periodic model snapshots
+display_stats_freq: 1 # validation & stats/saving frequency, in epochs
+num_epochs: 1000
\ No newline at end of file
diff --git a/data/data_provider_brats.py b/data/data_provider_brats.py
index 6844bb8..27ae5a4 100644
--- a/data/data_provider_brats.py
+++ b/data/data_provider_brats.py
@@ -28,7 +28,7 @@ class data_loader_3D(Dataset):
     """
     This is the pipeline based on Pytorch's Dataset and Dataloader
     """
-    def __init__(self, cfg_path, mode='train', modality=2, multimodal=True):
+    def __init__(self, cfg_path, mode='train', modality=2, multimodal=True, site=None):
         """
         Parameters
         ----------
@@ -62,10 +62,13 @@ def __init__(self, cfg_path, mode='train', modality=2, multimodal=True):
         elif mode == 'test':
             self.subset_df = org_df[org_df['soroosh_split'] == 'test']
 
+        if site is not None:
+            self.subset_df = self.subset_df[self.subset_df['site'] == site]
 
         self.file_path_list = list(self.subset_df['pat_num'])
 
+
     def __len__(self):
         """Returns the length of the dataset"""
         return len(self.file_path_list)
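Review note (data_provider_brats.py): the new site argument narrows the split dataframe to a single hospital before the patient list is built, which is what gives each federated client a disjoint shard of the data. A small sketch of that selection logic with a made-up dataframe (the column values below are illustrative; the real ones come from the dataset CSV):

    import pandas as pd

    org_df = pd.DataFrame({'pat_num': [11, 12, 13, 14],
                           'soroosh_split': ['train', 'train', 'train', 'valid'],
                           'site': ['site-1', 'site-2', 'site-1', 'site-1']})

    subset_df = org_df[org_df['soroosh_split'] == 'train']
    site = 'site-1'
    if site is not None:
        subset_df = subset_df[subset_df['site'] == site]
    file_path_list = list(subset_df['pat_num'])  # -> [11, 13]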
     trainer.train_epoch(train_loader=train_loader, valid_loader=valid_loader)
 
 
+def main_train_federated_3D(global_config_path="/home/soroosh/Documents/Repositories/federated_he/config/config.yaml", valid=False,
+                            resume=False, augment=False, experiment_name='name', modality=2):
+    """Main function for federated training + validation, directly 3D-wise
+
+    Parameters
+    ----------
+    global_config_path: str
+        always global_config_path="/home/soroosh/Documents/Repositories/federated_he/config/config.yaml"
+
+    valid: bool
+        if we want to do validation
+
+    resume: bool
+        if we are resuming training on a model
+
+    augment: bool
+        if we want to have data augmentation during training
+
+    experiment_name: str
+        name of the experiment, in case of resuming training.
+        name of new experiment, in case of new training.
+
+    modality: int
+        modality of the MR sequence (currently not forwarded to data_loader_3D)
+        1: T1
+        2: T1Gd
+        3: T2
+        4: T2-FLAIR
+    """
+    if resume == True:
+        params = open_experiment(experiment_name, global_config_path)
+    else:
+        params = create_experiment(experiment_name, global_config_path)
+    cfg_path = params["cfg_path"]
+
+    # Changeable network parameters
+    model = UNet3D(n_out_classes=3)  # for multi label
+
+    loss_function = EDiceLoss  # for multi label
+    optimizer = torch.optim.Adam(model.parameters(), lr=float(params['Network']['lr']),
+                                 weight_decay=float(params['Network']['weight_decay']), amsgrad=params['Network']['amsgrad'])
+
+    train_dataset_client1 = data_loader_3D(cfg_path=cfg_path, mode='train', site='site-1')
+    train_loader_client1 = torch.utils.data.DataLoader(dataset=train_dataset_client1, batch_size=params['Network']['batch_size'],
+                                                       pin_memory=True, drop_last=True, shuffle=False, num_workers=4)
+    train_dataset_client2 = data_loader_3D(cfg_path=cfg_path, mode='train', site='site-2')
+    train_loader_client2 = torch.utils.data.DataLoader(dataset=train_dataset_client2, batch_size=params['Network']['batch_size'],
+                                                       pin_memory=True, drop_last=True, shuffle=False, num_workers=4)
+    train_dataset_client3 = data_loader_3D(cfg_path=cfg_path, mode='train', site='site-3')
+    train_loader_client3 = torch.utils.data.DataLoader(dataset=train_dataset_client3, batch_size=params['Network']['batch_size'],
+                                                       pin_memory=True, drop_last=True, shuffle=False, num_workers=4)
+
+    if valid:
+        valid_dataset = data_loader_3D(cfg_path=cfg_path, mode='valid')
+        valid_loader = torch.utils.data.DataLoader(dataset=valid_dataset, batch_size=params['Network']['batch_size'],
+                                                   pin_memory=True, drop_last=True, shuffle=False, num_workers=4)
+    else:
+        valid_loader = None
+
+    train_loader = [train_loader_client1, train_loader_client2, train_loader_client3]
+    trainer = Training(cfg_path, num_epochs=params['num_epochs'], resume=resume, augment=augment)
+    if resume == True:
+        trainer.load_checkpoint(model=model, optimiser=optimizer, loss_function=loss_function, weight=None)
+    else:
+        trainer.setup_model(model=model, optimiser=optimizer,
+                            loss_function=loss_function, weight=None)
+    trainer.training_setup_federated(train_loader=train_loader, valid_loader=valid_loader)
+
+
 def main_test_federated_3D(global_config_path="/home/soroosh/Documents/Repositories/federated_he/config/config.yaml",
@@ -251,10 +321,12 @@ def main_predict_3D(global_config_path="/home/soroosh/Documents/Repositories/fed
 
 if __name__ == '__main__':
-    # delete_experiment(experiment_name='4levelunet24_flip_gamma_AWGN_zoomin_central_full_lr1e4_80_80_80', global_config_path="/home/soroosh/Documents/Repositories/federated_he/config/config.yaml")
-    main_train_3D(global_config_path="/home/soroosh/Documents/Repositories/federated_he/config/config.yaml",
-                  valid=True, resume=False, augment=True, experiment_name='4levelunet24_flip_gamma_AWGN_zoomin_central_full_lr1e4_80_80_80')
+    # delete_experiment(experiment_name='federatedtest', global_config_path="/home/soroosh/Documents/Repositories/federated_he/config/config.yaml")
+    # main_train_3D(global_config_path="/home/soroosh/Documents/Repositories/federated_he/config/config.yaml",
+    #               valid=True, resume=False, augment=True, experiment_name='4levelunet24_flip_gamma_AWGN_central_full_lr1e4_80_80_80')
+    main_train_federated_3D(global_config_path="/home/soroosh/Documents/Repositories/federated_he/config/config.yaml",
+                            valid=True, resume=False, augment=True, experiment_name='federated_full_3client_4levelunet24_flip_gamma_AWGN_lr1e4_80_80_80')
     # main_evaluate_3D(global_config_path="/home/soroosh/Documents/Repositories/federated_he/config/config.yaml",
-    #                  experiment_name='multimod_24_4level_flip_gamma_AWGN_blur_cropped_minmax_central_full_lr1e4_80_80_100', tta=False)
+    #                  experiment_name='4levelunet24_flip_gamma_AWGN_blur_zoomin_central_full_lr1e4_80_80_80', tta=False)
     # main_predict_3D(global_config_path="/home/soroosh/Documents/Repositories/federated_he/config/config.yaml",
-    #                 experiment_name='4levelunet24_flip_gamma_AWGN_zoomin_central_full_lr3e4_80_80_100', tta=False)
+    #                 experiment_name='4levelunet24_flip_gamma_AWGN_blur_zoomin_central_full_lr1e4_80_80_80', tta=False)
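Review note (main_3D_brats.py / Train_Valid_brats.py): main_train_federated_3D() hands the three per-site loaders over as a plain list, and training_setup_federated() indexes it positionally (train_loader[0..2] -> client1..3), so list order is the only thing binding a site to a virtual worker. Inside train_epoch_federated() each batch is then shipped to the worker holding the client model (model.location), the backward pass and optimizer step run remotely, and only the accumulated loss is pulled back once per epoch. A minimal sketch of that remote-epoch pattern, again assuming PySyft 0.2.x and toy tensors in place of the BraTS volumes:

    import torch
    import syft as sy

    hook = sy.TorchHook(torch)
    client = sy.VirtualWorker(hook, id="client")

    model = torch.nn.Linear(4, 1).send(client)          # client-side model copy
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    batch_loss = 0
    num_batches = 3                                     # stand-in for len(train_loader)
    for _ in range(num_batches):
        image = torch.randn(2, 4).send(model.location)  # batch goes to the data owner
        label = torch.randn(2, 1).send(model.location)

        optimizer.zero_grad()
        output = model(image)                           # forward runs on the remote worker
        loss = ((output - label) ** 2).mean()           # toy loss in place of EDiceLoss
        loss.backward()
        optimizer.step()
        batch_loss += loss                              # still a pointer; stays remote

    avg_loss = (batch_loss.get().data / num_batches).item()  # one transfer per epoch

The patch's train_epoch_federated() follows exactly this shape, with the augmentation step and EDiceLoss in place of the toy pieces.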
diff --git a/models/EDiceLoss_loss.py b/models/EDiceLoss_loss.py
index 0098181..7c01563 100644
--- a/models/EDiceLoss_loss.py
+++ b/models/EDiceLoss_loss.py
@@ -51,17 +51,17 @@ def compute_intersection(inputs, targets):
 
     def forward(self, inputs, target):
         dice = 0
-        for i in range(target.size(1)):
+        for i in range(target.shape[1]):
             dice = dice + self.binary_dice(inputs[:, i, ...], target[:, i, ...], i)
 
-        final_dice = dice / target.size(1)
+        final_dice = dice / target.shape[1]
 
         return final_dice
 
     def metric(self, inputs, target):
         dices = []
-        for j in range(target.size(0)):
+        for j in range(target.shape[0]):
             dice = []
-            for i in range(target.size(1)):
+            for i in range(target.shape[1]):
                 dice.append(self.binary_dice(inputs[j, i], target[j, i], i, True))
             dices.append(dice)
         return dices
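Review note (models/EDiceLoss_loss.py): the size() -> shape change is behavior-preserving on plain torch tensors (t.size(1) == t.shape[1]); presumably it was made for compatibility with PySyft's pointer tensors. The loss itself averages a per-channel binary Dice over the label channels. A small worked example of that reduction, with a hypothetical soft-Dice stand-in for the class's binary_dice() (the real method also carries per-class bookkeeping):

    import torch

    def binary_dice(inputs, targets, smooth=1.0):
        # soft Dice for one channel: (2*|X ∩ Y| + s) / (|X| + |Y| + s)
        intersection = (inputs * targets).sum()
        return (2.0 * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)

    inputs = torch.rand(2, 3, 8, 8, 8)                  # (batch, channel, D, H, W) probabilities
    target = (torch.rand(2, 3, 8, 8, 8) > 0.5).float()  # 3 binary label channels

    dice = 0
    for i in range(target.shape[1]):                    # same loop as EDiceLoss.forward
        dice = dice + binary_dice(inputs[:, i, ...], target[:, i, ...])
    final_dice = dice / target.shape[1]                 # mean Dice over the 3 channels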