From c81b1def05eab5f12a0f9a64df545c783b6f255e Mon Sep 17 00:00:00 2001 From: tayebiarasteh Date: Thu, 31 Mar 2022 20:44:29 +0200 Subject: [PATCH] initial commit --- .gitignore | 5 + Prediction_brats.py | 373 ++++++++++++++++++++++++ Train_Valid_brats.py | 468 ++++++++++++++++++++++++++++++ config/config.yaml | 48 +++ config/serde.py | 77 +++++ data/augmentation_brats.py | 146 ++++++++++ data/csv_data_preprocess_brats.py | 285 ++++++++++++++++++ data/data_provider_brats.py | 414 ++++++++++++++++++++++++++ data/data_utils.py | 147 ++++++++++ main_3D_brats.py | 260 +++++++++++++++++ models/Diceloss.py | 88 ++++++ models/EDiceLoss_loss.py | 83 ++++++ models/UNet3D.py | 229 +++++++++++++++ 13 files changed, 2623 insertions(+) create mode 100644 .gitignore create mode 100644 Prediction_brats.py create mode 100644 Train_Valid_brats.py create mode 100644 config/config.yaml create mode 100644 config/serde.py create mode 100644 data/augmentation_brats.py create mode 100755 data/csv_data_preprocess_brats.py create mode 100644 data/data_provider_brats.py create mode 100644 data/data_utils.py create mode 100644 main_3D_brats.py create mode 100644 models/Diceloss.py create mode 100644 models/EDiceLoss_loss.py create mode 100644 models/UNet3D.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..030f0ba --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +/data/__pycache__/ +/__pycache__/ +/.idea/ +/config/__pycache__/ +/models/__pycache__/ diff --git a/Prediction_brats.py b/Prediction_brats.py new file mode 100644 index 0000000..c4675e8 --- /dev/null +++ b/Prediction_brats.py @@ -0,0 +1,373 @@ +""" +Created on March 8, 2022. +Prediction_brats.py + +@author: Soroosh Tayebi Arasteh +https://github.com/tayebiarasteh/ +""" + +import pdb +import torch +import os.path +import numpy as np +import torchmetrics +from tqdm import tqdm +import torch.nn.functional as F +import torchio as tio +import nibabel as nib + +from config.serde import read_config + +epsilon = 1e-15 + + + +class Prediction: + def __init__(self, cfg_path): + """ + This class represents prediction (testing) process similar to the Training class. + """ + self.params = read_config(cfg_path) + self.cfg_path = cfg_path + self.setup_cuda() + + + def setup_cuda(self, cuda_device_id=0): + """setup the device. 
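+        Selects the given CUDA device when available; otherwise falls back to the CPU.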
+ Parameters + ---------- + cuda_device_id: int + cuda device id + """ + if torch.cuda.is_available(): + torch.backends.cudnn.fastest = True + torch.cuda.set_device(cuda_device_id) + self.device = torch.device('cuda') + else: + self.device = torch.device('cpu') + + + def setup_model(self, model, model_file_name=None): + if model_file_name == None: + model_file_name = self.params['trained_model_name'] + self.model = model.to(self.device) + + self.model.load_state_dict(torch.load(os.path.join(self.params['target_dir'], self.params['network_output_path'], model_file_name))) + # self.model.load_state_dict(torch.load(os.path.join(self.params['target_dir'], self.params['network_output_path']) + "step2400_" + model_file_name)) + + + + def setup_model_federated(self, model, model_file_name=None): + if model_file_name == None: + model_file_name = self.params['trained_model_name'] + self.model = model.to(self.device) + + state_dict = torch.load(os.path.join(self.params['target_dir'], self.params['network_output_path'], model_file_name)) + self.model.load_state_dict(state_dict['model']) + # self.model.load_state_dict(state_dict) + + + + def evaluate_3D(self, test_loader): + """Evaluation with metrics epoch + Returns + ------- + epoch_f1_score: float + average test F1 score + average_specifity: float + average test specifity + average_sensitivity: float + average test sensitivity + average_precision: float + average test precision + """ + self.model.eval() + total_f1_score = [] + total_accuracy = [] + total_specifity_score = [] + total_sensitivity_score = [] + total_precision_score = [] + + for idx, (image, label) in enumerate(tqdm(test_loader)): + label = label.long() + image = image.float() + image = image.to(self.device) + label = label.to(self.device) + + with torch.no_grad(): + output = self.model(image) + output_sigmoided = F.sigmoid(output.permute(0, 2, 3, 4, 1)) + output_sigmoided = (output_sigmoided > 0.5).float() + + ############ Evaluation metric calculation ######## + # Metrics calculation (macro) over the whole set + confusioner = torchmetrics.ConfusionMatrix(num_classes=label.shape[1], multilabel=True).to(self.device) + confusion = confusioner(output_sigmoided.flatten(start_dim=0, end_dim=3), + label.permute(0, 2, 3, 4, 1).flatten(start_dim=0, end_dim=3)) + + F1_disease = [] + accuracy_disease = [] + specifity_disease = [] + sensitivity_disease = [] + precision_disease = [] + + for idx, disease in enumerate(confusion): + TN = disease[0, 0] + FP = disease[0, 1] + FN = disease[1, 0] + TP = disease[1, 1] + F1_disease.append(2 * TP / (2 * TP + FN + FP + epsilon)) + accuracy_disease.append((TP + TN) / (TP + TN + FP + FN + epsilon)) + specifity_disease.append(TN / (TN + FP + epsilon)) + sensitivity_disease.append(TP / (TP + FN + epsilon)) + precision_disease.append(TP / (TP + FP + epsilon)) + + # Macro averaging + total_f1_score.append(torch.stack(F1_disease)) + total_accuracy.append(torch.stack(accuracy_disease)) + total_specifity_score.append(torch.stack(specifity_disease)) + total_sensitivity_score.append(torch.stack(sensitivity_disease)) + total_precision_score.append(torch.stack(precision_disease)) + + average_f1_score = torch.stack(total_f1_score).mean(0) + average_accuracy = torch.stack(total_accuracy).mean(0) + average_specifity = torch.stack(total_specifity_score).mean(0) + average_sensitivity = torch.stack(total_sensitivity_score).mean(0) + average_precision = torch.stack(total_precision_score).mean(0) + + return average_f1_score, average_accuracy, average_specifity, 
average_sensitivity, average_precision
+
+
+
+    def evaluate_3D_tta(self, test_loader):
+        """Evaluation epoch with metrics, applying test-time augmentation (TTA)
+
+        Returns
+        -------
+        average_f1_score: float
+            average test F1 score
+
+        average_specifity: float
+            average test specificity
+
+        average_sensitivity: float
+            average test sensitivity
+
+        average_precision: float
+            average test precision
+        """
+        self.model.eval()
+        total_f1_score = []
+        total_accuracy = []
+        total_specifity_score = []
+        total_sensitivity_score = []
+        total_precision_score = []
+
+        for idx, (image, label) in enumerate(tqdm(test_loader)):
+
+            label = label.long()
+            image = image.float()
+
+            with torch.no_grad():
+
+                output_normal = self.model(image.to(self.device))
+                output_normal = output_normal.cpu()
+
+                # augmentation: lateral flip (the prediction is flipped back afterwards)
+                transformed_image, transform = self.tta_performer(image, 'lateral_flip')
+                transformed_image = transformed_image.to(self.device)
+                output = self.model(transformed_image)
+                output_back1 = transform(output[0].cpu())
+                output_back1 = output_back1.unsqueeze(0)
+
+                # augmentation: inferior flip (the prediction is flipped back afterwards)
+                transformed_image, transform = self.tta_performer(image, 'interior_flip')
+                transformed_image = transformed_image.to(self.device)
+                output = self.model(transformed_image)
+                output_back5 = transform(output[0].cpu())
+                output_back5 = output_back5.unsqueeze(0)
+
+                # augmentation: additive white Gaussian noise
+                transformed_image, transform = self.tta_performer(image, 'AWGN')
+                transformed_image = transformed_image.to(self.device)
+                output_back2 = self.model(transformed_image)
+                output_back2 = output_back2.cpu()
+
+                # augmentation: gamma
+                transformed_image, transform = self.tta_performer(image, 'gamma')
+                transformed_image = transformed_image.to(self.device)
+                output_back3 = self.model(transformed_image)
+                output_back3 = output_back3.cpu()
+
+                # augmentation: blur
+                transformed_image, transform = self.tta_performer(image, 'blur')
+                transformed_image = transformed_image.to(self.device)
+                output_back4 = self.model(transformed_image)
+                output_back4 = output_back4.cpu()
+
+                # ensembling the predictions (the unaugmented prediction is counted twice)
+                output = (output_normal + output_normal + output_back1 + output_back2 +
+                          output_back3 + output_back4 + output_back5) / 7
+
+                output = output.to(self.device)
+
+                output_sigmoided = F.sigmoid(output.permute(0, 2, 3, 4, 1))
+                output_sigmoided = (output_sigmoided > 0.5).float()
+
+            label = label.to(self.device)
+
+            ############ Evaluation metric calculation ########
+            # Metrics calculation (macro) over the whole set
+            confusioner = torchmetrics.ConfusionMatrix(num_classes=label.shape[1], multilabel=True).to(self.device)
+            confusion = confusioner(output_sigmoided.flatten(start_dim=0, end_dim=3),
+                                    label.permute(0, 2, 3, 4, 1).flatten(start_dim=0, end_dim=3))
+
+            F1_disease = []
+            accuracy_disease = []
+            specifity_disease = []
+            sensitivity_disease = []
+            precision_disease = []
+
+            for idx, disease in enumerate(confusion):
+                TN = disease[0, 0]
+                FP = disease[0, 1]
+                FN = disease[1, 0]
+                TP = disease[1, 1]
+                F1_disease.append(2 * TP / (2 * TP + FN + FP + epsilon))
+                accuracy_disease.append((TP + TN) / (TP + TN + FP + FN + epsilon))
+                specifity_disease.append(TN / (TN + FP + epsilon))
+                sensitivity_disease.append(TP / (TP + FN + epsilon))
+                precision_disease.append(TP / (TP + FP + epsilon))
+
+            # Macro averaging
+            total_f1_score.append(torch.stack(F1_disease))
+            total_accuracy.append(torch.stack(accuracy_disease))
+            total_specifity_score.append(torch.stack(specifity_disease))
+            total_sensitivity_score.append(torch.stack(sensitivity_disease))
+
total_precision_score.append(torch.stack(precision_disease)) + + average_f1_score = torch.stack(total_f1_score).mean(0) + average_accuracy = torch.stack(total_accuracy).mean(0) + average_specifity = torch.stack(total_specifity_score).mean(0) + average_sensitivity = torch.stack(total_sensitivity_score).mean(0) + average_precision = torch.stack(total_precision_score).mean(0) + + return average_f1_score, average_accuracy, average_specifity, average_sensitivity, average_precision + + + + + def predict_3D(self, image): + """Prediction of one signle image + + Returns + ------- + """ + self.model.eval() + + image = image.float() + image = image.to(self.device) + + with torch.no_grad(): + output = self.model(image) + output_sigmoided = F.sigmoid(output) + output_sigmoided = (output_sigmoided > 0.5).float() + + return output_sigmoided + + + + def predict_3D_tta(self, image): + """Prediction of one signle image using test-time augmentation + + Returns + ------- + """ + self.model.eval() + + image = image.float() + + with torch.no_grad(): + output_normal = self.model(image.to(self.device)) + output_normal = output_normal.cpu() + + # augmentation + transformed_image, transform = self.tta_performer(image, 'lateral_flip') + transformed_image = transformed_image.to(self.device) + output = self.model(transformed_image) + output_back1 = transform(output[0].cpu()) + output_back1 = output_back1.unsqueeze(0) + + # augmentation + transformed_image, transform = self.tta_performer(image, 'interior_flip') + transformed_image = transformed_image.to(self.device) + output = self.model(transformed_image) + output_back5 = transform(output[0].cpu()) + output_back5 = output_back5.unsqueeze(0) + + # augmentation + transformed_image, transform = self.tta_performer(image, 'AWGN') + transformed_image = transformed_image.to(self.device) + output_back2 = self.model(transformed_image) + output_back2 = output_back2.cpu() + + # augmentation + transformed_image, transform = self.tta_performer(image, 'gamma') + transformed_image = transformed_image.to(self.device) + output_back3 = self.model(transformed_image) + output_back3 = output_back3.cpu() + + # augmentation + transformed_image, transform = self.tta_performer(image, 'blur') + transformed_image = transformed_image.to(self.device) + output_back4 = self.model(transformed_image) + output_back4 = output_back4.cpu() + + # ensembling the predictions + output = (output_normal + output_normal + output_back1 + output_back2 + + output_back3 + output_back4 + output_back5) / 7 + + output = output.to(self.device) + + output_sigmoided = F.sigmoid(output) + output_sigmoided = (output_sigmoided > 0.5).float() + + return output_sigmoided + + + + def tta_performer(self, image, transform_type): + """applying test-time augmentation + """ + + if transform_type == 'lateral_flip': + transform = tio.transforms.RandomFlip(axes='L', flip_probability=1) + + if transform_type == 'interior_flip': + transform = tio.transforms.RandomFlip(axes='I', flip_probability=1) + + elif transform_type == 'AWGN': + transform = tio.RandomNoise(mean=self.params['augmentation']['mu_AWGN'], std=self.params['augmentation']['sigma_AWGN']) + + elif transform_type == 'gamma': + transform = tio.RandomGamma(log_gamma=(self.params['augmentation']['gamma_range'][0], self.params['augmentation']['gamma_range'][1])) + + elif transform_type == 'blur': + transform = tio.RandomBlur(std=(self.params['augmentation']['gamma_range'][0], self.params['augmentation']['gamma_range'][1])) + + # normalized_img = 
nib.Nifti1Image(image[0,0].numpy(), np.eye(4)) + # nib.save(normalized_img, 'orggg.nii.gz') + + trans_img = transform(image[0]) + # normalized_img = nib.Nifti1Image(trans_img[0].numpy(), np.eye(4)) + # nib.save(normalized_img, 'tta_img.nii.gz') + + # transform = tio.RandomAffine(scales=(1.05, 1.05), translation=0, degrees=0, default_pad_value='minimum', + # image_interpolation='nearest') + # image = transform(trans_img) + # normalized_img = nib.Nifti1Image(image[0].numpy(), np.eye(4)) + # nib.save(normalized_img, 'tta_img_back.nii.gz') + + # pdb.set_trace() + + return trans_img.unsqueeze(0), transform diff --git a/Train_Valid_brats.py b/Train_Valid_brats.py new file mode 100644 index 0000000..c9dda13 --- /dev/null +++ b/Train_Valid_brats.py @@ -0,0 +1,468 @@ +""" +Created on March 4, 2022. +Train_Valid_brats.py + +@author: Soroosh Tayebi Arasteh +https://github.com/tayebiarasteh/ +""" + +import os.path +import time +import pdb +from tensorboardX import SummaryWriter +import torch +from tqdm import tqdm +import torchmetrics +import torch.nn.functional as F + +from config.serde import read_config, write_config +from data.augmentation_brats import random_augment + +import warnings +warnings.filterwarnings('ignore') +epsilon = 1e-15 + + + +class Training: + def __init__(self, cfg_path, num_epochs=10, resume=False, augment=False): + """This class represents training and validation processes. + + Parameters + ---------- + cfg_path: str + Config file path of the experiment + + num_epochs: int + Total number of epochs for training + + resume: bool + if we are resuming training from a checkpoint + """ + self.params = read_config(cfg_path) + self.cfg_path = cfg_path + self.num_epochs = num_epochs + self.augment = augment + + if resume == False: + self.model_info = self.params['Network'] + self.epoch = 0 + self.step = 0 + self.best_loss = float('inf') + self.setup_cuda() + self.writer = SummaryWriter(log_dir=os.path.join(self.params['target_dir'], self.params['tb_logs_path'])) + + + def setup_cuda(self, cuda_device_id=0): + """setup the device. + + Parameters + ---------- + cuda_device_id: int + cuda device id + """ + if torch.cuda.is_available(): + torch.backends.cudnn.fastest = True + torch.cuda.set_device(cuda_device_id) + self.device = torch.device('cuda') + else: + self.device = torch.device('cpu') + + + def time_duration(self, start_time, end_time): + """calculating the duration of training or one iteration + + Parameters + ---------- + start_time: float + starting time of the operation + + end_time: float + ending time of the operation + + Returns + ------- + elapsed_hours: int + total hours part of the elapsed time + + elapsed_mins: int + total minutes part of the elapsed time + + elapsed_secs: int + total seconds part of the elapsed time + """ + elapsed_time = end_time - start_time + elapsed_hours = int(elapsed_time / 3600) + if elapsed_hours >= 1: + elapsed_mins = int((elapsed_time / 60) - (elapsed_hours * 60)) + elapsed_secs = int(elapsed_time - (elapsed_hours * 3600) - (elapsed_mins * 60)) + else: + elapsed_mins = int(elapsed_time / 60) + elapsed_secs = int(elapsed_time - (elapsed_mins * 60)) + return elapsed_hours, elapsed_mins, elapsed_secs + + + def setup_model(self, model, optimiser, loss_function, weight=None): + """Setting up all the models, optimizers, and loss functions. 
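+        Also counts the trainable parameters and writes them, together with the loss function name and number of epochs, into the experiment config file.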
+ + Parameters + ---------- + model: model file + The network + + optimiser: optimizer file + The optimizer + + loss_function: loss file + The loss function + + weight: 1D tensor of float + class weights + """ + + # prints the network's total number of trainable parameters and + # stores it to the experiment config + total_param_num = sum(p.numel() for p in model.parameters() if p.requires_grad) + print(f'\nTotal # of trainable parameters: {total_param_num:,}') + print('----------------------------------------------------\n') + + self.model = model.to(self.device) + if not weight==None: + # self.loss_weight = weight.to(self.device) + # self.loss_function = loss_function(self.loss_weight) # for binary + # self.loss_function = loss_function(pos_weight=self.loss_weight) # for multi label + self.loss_function = loss_function() + else: + self.loss_function = loss_function() + self.optimiser = optimiser + + # Saves the model, optimiser,loss function name for writing to config file + # self.model_info['model'] = model.__name__ + # self.model_info['optimiser'] = optimiser.__name__ + self.model_info['total_param_num'] = total_param_num + self.model_info['loss_function'] = loss_function.__name__ + self.model_info['num_epochs'] = self.num_epochs + self.params['Network'] = self.model_info + write_config(self.params, self.cfg_path, sort_keys=True) + + + def load_checkpoint(self, model, optimiser, loss_function, weight=None): + """In case of resuming training from a checkpoint, + loads the weights for all the models, optimizers, and + loss functions, and device, tensorboard events, number + of iterations (epochs), and every info from checkpoint. + + Parameters + ---------- + model: model file + The network + + optimiser: optimizer file + The optimizer + + loss_function: loss file + The loss function + """ + checkpoint = torch.load(os.path.join(self.params['target_dir'], self.params['network_output_path'], + self.params['checkpoint_name'])) + self.device = None + self.model_info = checkpoint['model_info'] + self.setup_cuda() + self.model = model.to(self.device) + # self.loss_weight = weight.to(self.device) + # self.loss_function = loss_function(self.loss_weight) + self.loss_function = loss_function() + self.optimiser = optimiser + + self.model.load_state_dict(checkpoint['model_state_dict']) + self.optimiser.load_state_dict(checkpoint['optimizer_state_dict']) + self.epoch = checkpoint['epoch'] + self.step = checkpoint['step'] + self.best_loss = checkpoint['best_loss'] + self.writer = SummaryWriter(log_dir=os.path.join(os.path.join( + self.params['target_dir'], self.params['tb_logs_path'])), purge_step=self.step + 1) + + + + def train_epoch(self, train_loader, valid_loader=None): + """Training epoch + """ + self.params = read_config(self.cfg_path) + + for epoch in range(self.num_epochs - self.epoch): + self.epoch += 1 + + # initializing the loss list + batch_loss = 0 + batch_count = 0 + + start_time = time.time() + total_start_time = time.time() + + for idx, (image, label) in enumerate(train_loader): + self.model.train() + + # if we would like to have data augmentation during training + if self.augment: + image, label = random_augment(image, label, self.cfg_path) + + label = label.long() + image = image.float() + image = image.to(self.device) + label = label.to(self.device) + + self.optimiser.zero_grad() + + with torch.set_grad_enabled(True): + + output = self.model(image) + + # loss = self.loss_function(output, label[:, 0]) # for cross entropy loss + loss = self.loss_function(output, label) # for dice 
loss + + loss.backward() + self.optimiser.step() + + batch_loss += loss.item() + batch_count += 1 + self.step += 1 + + # Prints train loss after number of steps specified. + if (self.step) % self.params['display_train_loss_freq'] == 0: + end_time = time.time() + iteration_hours, iteration_mins, iteration_secs = self.time_duration(start_time, end_time) + total_hours, total_mins, total_secs = self.time_duration(total_start_time, end_time) + train_loss = batch_loss / batch_count + batch_loss = 0 + batch_count = 0 + start_time = time.time() + + print('Step {} | train epoch {} | batch {} / {} | loss: {:.3f}'. + format(self.step, self.epoch, idx+1, len(train_loader), train_loss), + f'\ntime: {iteration_hours}h {iteration_mins}m {iteration_secs}s', + f'| total: {total_hours}h {total_mins}m {total_secs}s\n') + self.writer.add_scalar('Train_loss', train_loss, self.step) + + # Validation iteration & calculate metrics + if (self.step) % (self.params['display_stats_freq']) == 0: + + # saving the model, checkpoint, TensorBoard, etc. + if not valid_loader == None: + valid_loss, valid_F1, valid_accuracy, valid_specifity, valid_sensitivity, valid_precision = self.valid_epoch(valid_loader) + end_time = time.time() + total_hours, total_mins, total_secs = self.time_duration(total_start_time, end_time) + + self.calculate_tb_stats(valid_loss=valid_loss, valid_F1=valid_F1, valid_accuracy=valid_accuracy, valid_specifity=valid_specifity, + valid_sensitivity=valid_sensitivity, valid_precision=valid_precision) + self.savings_prints(iteration_hours, iteration_mins, iteration_secs, total_hours, + total_mins, total_secs, train_loss, valid_loss=valid_loss, + valid_F1=valid_F1, valid_accuracy=valid_accuracy, valid_specifity= valid_specifity, + valid_sensitivity=valid_sensitivity, valid_precision=valid_precision) + else: + self.savings_prints(iteration_hours, iteration_mins, iteration_secs, total_hours, + total_mins, total_secs, train_loss) + + + + def valid_epoch(self, valid_loader): + """Validation epoch + + """ + self.model.eval() + total_loss = 0.0 + total_f1_score = [] + total_accuracy = [] + total_specifity_score = [] + total_sensitivity_score = [] + total_precision_score = [] + + for idx, (image, label) in enumerate(valid_loader): + label = label.long() + image = image.float() + image = image.to(self.device) + label = label.to(self.device) + + with torch.no_grad(): + output = self.model(image) + # loss = self.loss_function(output, label[:, 0]) # for cross entropy loss + loss = self.loss_function(output, label) # for dice loss + + # max_preds = output.argmax(dim=1, keepdim=True) # get the index of the max probability (multi-class) + output_sigmoided = F.sigmoid(output.permute(0, 2, 3, 4, 1)) + output_sigmoided = (output_sigmoided > 0.5).float() + + ############ Evaluation metric calculation ######## + total_loss += loss.item() + + # Metrics calculation (macro) over the whole set + confusioner = torchmetrics.ConfusionMatrix(num_classes=label.shape[1], multilabel=True).to(self.device) + confusion = confusioner(output_sigmoided.flatten(start_dim=0, end_dim=3), label.permute(0, 2, 3, 4, 1).flatten(start_dim=0, end_dim=3)) + + F1_disease = [] + accuracy_disease = [] + specifity_disease = [] + sensitivity_disease = [] + precision_disease = [] + + for idx, disease in enumerate(confusion): + TN = disease[0, 0] + FP = disease[0, 1] + FN = disease[1, 0] + TP = disease[1, 1] + F1_disease.append(2 * TP / (2 * TP + FN + FP + epsilon)) + accuracy_disease.append((TP + TN) / (TP + TN + FP + FN + epsilon)) + 
specifity_disease.append(TN / (TN + FP + epsilon)) + sensitivity_disease.append(TP / (TP + FN + epsilon)) + precision_disease.append(TP / (TP + FP + epsilon)) + + # Macro averaging + total_f1_score.append(torch.stack(F1_disease)) + total_accuracy.append(torch.stack(accuracy_disease)) + total_specifity_score.append(torch.stack(specifity_disease)) + total_sensitivity_score.append(torch.stack(sensitivity_disease)) + total_precision_score.append(torch.stack(precision_disease)) + + average_loss = total_loss / len(valid_loader) + average_f1_score = torch.stack(total_f1_score).mean(0) + average_accuracy = torch.stack(total_accuracy).mean(0) + average_specifity = torch.stack(total_specifity_score).mean(0) + average_sensitivity = torch.stack(total_sensitivity_score).mean(0) + average_precision = torch.stack(total_precision_score).mean(0) + + return average_loss, average_f1_score, average_accuracy, average_specifity, average_sensitivity, average_precision + + + + def savings_prints(self, iteration_hours, iteration_mins, iteration_secs, total_hours, + total_mins, total_secs, train_loss, valid_loss=None, valid_F1=None, valid_accuracy=None, + valid_specifity=None, valid_sensitivity=None, valid_precision=None): + """Saving the model weights, checkpoint, information, + and training and validation loss and evaluation statistics. + + Parameters + ---------- + iteration_hours: int + hours part of the elapsed time of each iteration + + iteration_mins: int + minutes part of the elapsed time of each iteration + + iteration_secs: int + seconds part of the elapsed time of each iteration + + total_hours: int + hours part of the total elapsed time + + total_mins: int + minutes part of the total elapsed time + + total_secs: int + seconds part of the total elapsed time + + train_loss: float + training loss of the model + + valid_acc: float + validation accuracy of the model + + valid_sensitivity: float + validation sensitivity of the model + + valid_specifity: float + validation specifity of the model + + valid_loss: float + validation loss of the model + """ + + # Saves information about training to config file + self.params['Network']['num_epoch'] = self.epoch + self.params['Network']['step'] = self.step + write_config(self.params, self.cfg_path, sort_keys=True) + + # Saving the model based on the best loss + if valid_loss: + if valid_loss < self.best_loss: + self.best_loss = valid_loss + torch.save(self.model.state_dict(), os.path.join(self.params['target_dir'], + self.params['network_output_path'], self.params['trained_model_name'])) + else: + if train_loss < self.best_loss: + self.best_loss = train_loss + torch.save(self.model.state_dict(), os.path.join(self.params['target_dir'], + self.params['network_output_path'], self.params['trained_model_name'])) + + # Saving every couple of iterations + if (self.step) % self.params['network_save_freq'] == 0: + torch.save(self.model.state_dict(), os.path.join(self.params['target_dir'], self.params['network_output_path'], + 'step{}_'.format(self.step) + self.params['trained_model_name'])) + + # Save a checkpoint every iteration + if (self.step) % self.params['network_checkpoint_freq'] == 0: + torch.save({'epoch': self.epoch, 'step': self.step, + 'model_state_dict': self.model.state_dict(), + 'optimizer_state_dict': self.optimiser.state_dict(), + 'loss_state_dict': self.loss_function.state_dict(), 'num_epochs': self.num_epochs, + 'model_info': self.model_info, 'best_loss': self.best_loss}, + os.path.join(self.params['target_dir'], self.params['network_output_path'], 
self.params['checkpoint_name'])) + + print('------------------------------------------------------' + '----------------------------------') + print(f'Step: {self.step} (epoch: {self.epoch}) | ' + f'iteration time: {iteration_hours}h {iteration_mins}m {iteration_secs}s | ' + f'total time: {total_hours}h {total_mins}m {total_secs}s') + print(f'\n\tTrain loss: {train_loss:.4f}') + + if valid_loss: + print(f'\t Val. loss: {valid_loss:.4f} | F1 (Dice score): {valid_F1.mean().item() * 100:.2f}% | accuracy: {valid_accuracy.mean().item() * 100:.2f}%' + f' | specifity: {valid_specifity.mean().item() * 100:.2f}%' + f' | recall (sensitivity): {valid_sensitivity.mean().item() * 100:.2f}% | precision: {valid_precision.mean().item() * 100:.2f}%\n') + + print('Individual F1 scores:') + print(f'Class 1: {valid_F1[0].item() * 100:.2f}%') + print(f'Class 2: {valid_F1[1].item() * 100:.2f}%') + print(f'Class 3: {valid_F1[2].item() * 100:.2f}%\n') + + # saving the training and validation stats + msg = f'----------------------------------------------------------------------------------------\n' \ + f'Step: {self.step} (epoch: {self.epoch}) | Iteration Time: {iteration_hours}h {iteration_mins}m {iteration_secs}s' \ + f' | total time: {total_hours}h {total_mins}m {total_secs}s\n\n\tTrain loss: {train_loss:.4f} | ' \ + f'Val. loss: {valid_loss:.4f} | F1 (Dice score): {valid_F1.mean().item() * 100:.2f}% | accuracy: {valid_accuracy.mean().item() * 100:.2f}% ' \ + f' | specifity: {valid_specifity.mean().item() * 100:.2f}%' \ + f' | recall (sensitivity): {valid_sensitivity.mean().item() * 100:.2f}% | precision: {valid_precision.mean().item() * 100:.2f}%\n\n' \ + f' | F1 class 1: {valid_F1[0].item() * 100:.2f}% | F1 class 2: {valid_F1[1].item() * 100:.2f}% | F1 class 3: {valid_F1[2].item() * 100:.2f}%\n\n' + else: + msg = f'----------------------------------------------------------------------------------------\n' \ + f'Step: {self.step} (epoch: {self.epoch}) | iteration time: {iteration_hours}h {iteration_mins}m {iteration_secs}s' \ + f' | total time: {total_hours}h {total_mins}m {total_secs}s\n\n\ttrain loss: {train_loss:.4f}\n\n' + with open(os.path.join(self.params['target_dir'], self.params['stat_log_path']) + '/Stats', 'a') as f: + f.write(msg) + + + + def calculate_tb_stats(self, valid_loss=None, valid_F1=None, valid_accuracy=None, valid_specifity=None, valid_sensitivity=None, valid_precision=None): + """Adds the evaluation metrics and loss values to the tensorboard. 
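+        Logged scalars: Valid_loss, Valid_specifity, Valid_F1 (mean and per class), Valid_precision, and Valid_recall_sensitivity.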
+ + Parameters + ---------- + valid_acc: float + validation accuracy of the model + + valid_sensitivity: float + validation sensitivity of the model + + valid_specifity: float + validation specifity of the model + + valid_loss: float + validation loss of the model + """ + if valid_loss is not None: + self.writer.add_scalar('Valid_loss', valid_loss, self.step) + self.writer.add_scalar('Valid_specifity', valid_specifity.mean().item(), self.step) + self.writer.add_scalar('Valid_F1', valid_F1.mean().item(), self.step) + self.writer.add_scalar('Valid_F1 class 1', valid_F1[0].item(), self.step) + self.writer.add_scalar('Valid_F1 class 2', valid_F1[1].item(), self.step) + self.writer.add_scalar('Valid_F1 class 3', valid_F1[2].item(), self.step) + self.writer.add_scalar('Valid_precision', valid_precision.mean().item(), self.step) + self.writer.add_scalar('Valid_recall_sensitivity', valid_sensitivity.mean().item(), self.step) \ No newline at end of file diff --git a/config/config.yaml b/config/config.yaml new file mode 100644 index 0000000..a50b74e --- /dev/null +++ b/config/config.yaml @@ -0,0 +1,48 @@ +Network: + batch_size: 1 + lr: 1e-4 + weight_decay: 0 + amsgrad: False +# resize_shape: [128,150,180] # d, h, w (training data are also stored in d, h, w) +# resize_shape: [100,160,160] # d, h, w (training data are also stored in d, h, w). 100*160*160 = 5.5 GB +# resize_shape: [115,180,180] # d, h, w (training data are also stored in d, h, w). 115*180*180 (75% of original) = 6 GB +# resize_shape: [24,24,24] # d, h, w (training data are also stored in d, h, w). after cropping. multi mode 4-level. 4.7 GB for 48; 3.4 GB for 24 + resize_shape: [80,80,80] # d, h, w (training data are also stored in d, h, w). after cropping. multi mode 4-level. 3.4 GB for 24 +augmentation: + general_spatial_probability: 0.5 # probability of having spatial data augmentation at all + general_intensity_probability: 0.5 # probability of having intensity data augmentation at all + flip_prob: 0.2 # probability of having flipping augmentation + zoom_range: [1, 1.15] # 0.1 means from 0.9 to 1.1 + zoom_prob: 0.2 # probability of having zooming augmentation [for the moment only zoom in: more than 1] + rotation_range: 5 # degrees. 1 = (-1, 1) + rotation_prob: -1 # probability of having rotation augmentation [DON'T USE FOR BRATS] + shift_range: 1 # pixels. 
1 = (-1, 1) + shift_prob: -1 # probability of having translation augmentation [DON'T USE FOR BRATS] + eladf_control_points: 7 + eladf_max_displacement: 8 + elastic_prob: -1 # probability of having elastic deformation augmentation [DON'T USE FOR BRATS] + gamma_range: [0, 0.2] # this number changes beta exponent distribution; don't go more than 0.2 + gamma_prob: 0.2 # probability of having gamma augmentation + mu_AWGN: 0.0 # mean of AWGN + sigma_AWGN: 0.1 # std of AWGN + AWGN_prob: 0.2 # probability of having AWGN augmentation + motion_prob: -1 # probability of having random motion augmentation [DON'T USE FOR BRATS] + ghosting_prob: -1 # probability of having ghosting augmentation + blurring_range: [0, 1] # range of std of Gaussian filter; don't go more than 1 + blurring_prob: 0.2 # probability of having blurring augmentation +network_output_path: network_data/ +output_data_path: output_data/ +tb_logs_path: tensor_board_logs/ +stat_log_path: stat_logs +checkpoint_name: checkpoint.tar +trained_model_name: trained_model.pth + + +# changeable items: +file_path: /home/soroosh/Documents/datasets/BraTS20/old_BraTS20/cropped/ +target_dir: /home/soroosh/Documents/Repositories_target_files/federated_he/ +network_save_freq: 600 # based on the batch size, shows the number of batches done; every half an hour is good +display_stats_freq: 200 # valid freq is equal to this +display_train_loss_freq: 20 +network_checkpoint_freq: 1 +num_epochs: 250 \ No newline at end of file diff --git a/config/serde.py b/config/serde.py new file mode 100644 index 0000000..4c9c0ab --- /dev/null +++ b/config/serde.py @@ -0,0 +1,77 @@ +""" +Created on November 10, 2019 +functions for writing/reading data to/from disk + +@modified_by: Soroosh Tayebi Arasteh +""" +import yaml +import numpy as np +import os +import warnings +import shutil +import pdb + + + + +def read_config(config_path): + """Reads config file in yaml format into a dictionary + + Parameters + ---------- + config_path: str + Path to the config file in yaml format + + Returns + ------- + config dictionary + """ + + with open(config_path, 'rb') as yaml_file: + return yaml.safe_load(yaml_file) + + +def write_config(params, cfg_path, sort_keys=False): + with open(cfg_path, 'w') as f: + yaml.dump(params, f) + + +def create_experiment(experiment_name, global_config_path): + params = read_config(global_config_path) + params['experiment_name'] = experiment_name + create_experiment_folders(params) + cfg_file_name = params['experiment_name'] + '_config.yaml' + cfg_path = os.path.join(os.path.join(params['target_dir'], params['network_output_path']), cfg_file_name) + params['cfg_path'] = cfg_path + write_config(params, cfg_path) + return params + + +def create_experiment_folders(params): + try: + path_keynames = ["network_output_path", "tb_logs_path", "stat_log_path", "output_data_path"] + for key in path_keynames: + params[key] = os.path.join(params['experiment_name'], params[key]) + os.makedirs(os.path.join(params['target_dir'], params[key])) + except: + raise Exception("Experiment already exist. 
Please try a different experiment name") + + +def open_experiment(experiment_name, global_config_path): + """Open Existing Experiments + """ + default_params = read_config(global_config_path) + cfg_file_name = experiment_name + '_config.yaml' + cfg_path = os.path.join(os.path.join(default_params['target_dir'], experiment_name, default_params['network_output_path']), cfg_file_name) + params = read_config(cfg_path) + return params + + +def delete_experiment(experiment_name, global_config_path): + """Delete Existing Experiment folder + """ + default_params = read_config(global_config_path) + cfg_file_name = experiment_name + '_config.yaml' + cfg_path = os.path.join(os.path.join(default_params['target_dir'], experiment_name, default_params['network_output_path']), cfg_file_name) + params = read_config(cfg_path) + shutil.rmtree(os.path.join(params['target_dir'], experiment_name)) diff --git a/data/augmentation_brats.py b/data/augmentation_brats.py new file mode 100644 index 0000000..4ffdf7a --- /dev/null +++ b/data/augmentation_brats.py @@ -0,0 +1,146 @@ +""" +Created on March 7, 2022. +augmentation_brats.py + +@author: Soroosh Tayebi Arasteh +https://github.com/tayebiarasteh/ +""" + + +import pdb +import torchio as tio +from random import random + +from config.serde import read_config + + + + +def random_spatial_brats_augmentation(image, label, confg_path='/home/soroosh/Documents/Repositories/federated_he/config/config.yaml'): + """Both image and the label should be augmented + 1. Random flip + 2, 3, 4. Random affine (zoom, rotation, shift) + - scales (zoom): Tuple (a1, b1, a2, b2, a3, b3) defining the scaling ranges. a1 to b1 range for the first dimension + 0.1 means from 0.9 to 1.1 = (0.9, 1.1) + - degrees (rotation): Tuple (a1, b1, a2, b2, a3, b3) defining the rotation ranges in degrees. a1 to b1 range for the first dimension + - translation (shift): Tuple (a1, b1, a2, b2, a3, b3) defining the translation ranges in mm. a1 to b1 range for the first dimension + - image_interpolation: 'nearest', 'linear', 'bspline', 'lanczos'. For the label we must choose 'nearest'. + - default_pad_value: 'mean', 'minimum'. For the label we must choose 'minimum'. + 5. Random Elastic deformation + - num_control_points: Number of control points along each dimension of the coarse grid (nx, ny, nz). + Smaller numbers generate smoother deformations. + The minimum number of control points is 4 as this transform uses cubic B-splines to interpolate displacement. + - max_displacement (8 is good for brats): Maximum displacement along each dimension at each control point + - locked_borders: If 0, all displacement vectors are kept. + If 1, displacement of control points at the border of the coarse grid will be set to 0. + If 2, displacement of control points at the border of the image (red dots in the image below) will also be set to 0. + Compose + first do flipping, then affine, then elastic + for brats don't do flipping. 
+ elastic deformation important for brats + rotation only one degree for brats + intensity augmentations are more important for brats + Parameters + ---------- + image: torch tensor (n, c, d, h, w) + label: torch tensor (n, c, d, h, w) + confg_path: str + + Returns + ------- + transformed_image.unsqueeze(0): torch tensor (n, c, d, h, w) + transformed_label.unsqueeze(0): torch tensor (n, c, d, h, w) + """ + params = read_config(confg_path) + + transform = tio.transforms.RandomFlip(axes='L', flip_probability=params['augmentation']['flip_prob']) + image = transform(image) + label = transform(label) + + if random() < params['augmentation']['elastic_prob']: + transform = tio.RandomElasticDeformation(num_control_points=(params['augmentation']['eladf_control_points']), + max_displacement=(params['augmentation']['eladf_max_displacement']), + locked_borders=2, image_interpolation='nearest') + image = transform(image) + label = transform(label) + return image, label + + if random() < params['augmentation']['zoom_prob']: + transform = tio.RandomAffine(scales=(params['augmentation']['zoom_range'][0], params['augmentation']['zoom_range'][1]), default_pad_value='minimum', + translation=(0,0,0), degrees=(0,0,0), image_interpolation='nearest') + image = transform(image) + label = transform(label) + + if random() < params['augmentation']['rotation_prob']: + transform = tio.RandomAffine(degrees=(params['augmentation']['rotation_range']), default_pad_value='minimum', + image_interpolation='nearest') + image = transform(image) + label = transform(label) + + if random() < params['augmentation']['shift_prob']: + transform = tio.RandomAffine(translation=(params['augmentation']['shift_range']), default_pad_value='minimum', + image_interpolation='nearest') + image = transform(image) + label = transform(label) + + assert len(label.unique()) == 2 + return image, label + + + +def random_intensity_brats_augmentation(image, confg_path='/home/soroosh/Documents/Repositories/federated_he/config/config.yaml'): + """Only image should be augmented + """ + params = read_config(confg_path) + + # additive Gaussian noise (not needed for min max normalization; 20% prob for the mean std normalization) + if random() < params['augmentation']['AWGN_prob']: + transform = tio.RandomNoise(mean=params['augmentation']['mu_AWGN'], std=params['augmentation']['sigma_AWGN']) + return transform(image) + + if random() < params['augmentation']['gamma_prob']: + transform = tio.RandomGamma(log_gamma=(params['augmentation']['gamma_range'][0], params['augmentation']['gamma_range'][1])) + return transform(image) + + if random() < params['augmentation']['motion_prob']: + transform = tio.RandomMotion(degrees=10, translation=10, num_transforms=2) + return transform(image) + + if random() < params['augmentation']['ghosting_prob']: + transform = tio.RandomGhosting(num_ghosts=10, axes=(0, 1, 3), intensity=0.3) + return transform(image) + + if random() < params['augmentation']['blurring_prob']: + transform = tio.RandomBlur(std=(params['augmentation']['gamma_range'][0], params['augmentation']['gamma_range'][1])) + return transform(image) + + return image + + + + + +def random_augment(image, label, confg_path='/home/soroosh/Documents/Repositories/federated_he/config/config.yaml'): + """ + Parameters + ---------- + image: torch tensor (n, c, d, h, w) + label: torch tensor (n, c, d, h, w) + confg_path: str + Returns + ------- + transformed_image: torch tensor (n, c, d, h, w) + transformed_label: torch tensor (n, c, d, h, w) + """ + params = 
read_config(confg_path) + + if random() < params['augmentation']['general_spatial_probability']: + transformed_image, transformed_label = random_spatial_brats_augmentation(image[0], label[0], confg_path) + return transformed_image.unsqueeze(0), transformed_label.unsqueeze(0) + + elif random() < params['augmentation']['general_intensity_probability']: + transformed_image = random_intensity_brats_augmentation(image[0], confg_path) + return transformed_image.unsqueeze(0), label + + else: + return image, label \ No newline at end of file diff --git a/data/csv_data_preprocess_brats.py b/data/csv_data_preprocess_brats.py new file mode 100755 index 0000000..990556f --- /dev/null +++ b/data/csv_data_preprocess_brats.py @@ -0,0 +1,285 @@ +""" +Created on March 4, 2022. +csv_data_preprocess_brats.py + +creating a master list for Brats dataset. + +@author: Soroosh Tayebi Arasteh +https://github.com/tayebiarasteh/ +""" + +import os +import pdb +import numpy as np +import pandas as pd +from tqdm import tqdm +import h5py +import nibabel as nib + +from config.serde import read_config + +import warnings +warnings.filterwarnings('ignore') + + + + +class csv_preprocess_brats(): + def __init__(self, cfg_path="/home/soroosh/Documents/Repositories/federated_he/config/config.yaml"): + self.params = read_config(cfg_path) + + def hd5_to_nifti(self): + """Converting HD5 data to nifti + we have 4 MRI modalities and 3 labels per patient + """ + + # filee = "/home/soroosh/Downloads/BraTS2020_training_data/content/data/volume_1_slice_40.h5" + org_df = "/home/soroosh/Downloads/BraTS2020_training_data/content/data/meta_data.csv" + target_base = "/home/soroosh/Documents/datasets/BraTS20/" + base_dir = '/home/soroosh/Downloads/BraTS2020_training_data/content/data' + + df = pd.read_csv(org_df, sep=',') + patient_list = df['volume'].unique().tolist() + + + for patient in tqdm(patient_list): + df_pat = df[df['volume'] == patient] + volume_mod1 = [] + volume_mod2 = [] + volume_mod3 = [] + volume_mod4 = [] + volume_mask1 = [] + volume_mask2 = [] + volume_mask3 = [] + + for i in range(len(df_pat)): + rel_path = df_pat[df_pat['slice'] == i]['slice_path'].values[0] + path = os.path.join(base_dir, os.path.basename(rel_path)) + hf = h5py.File(path, 'r') + volume_mod1.append(hf['image'][:, :, 0]) + volume_mod2.append(hf['image'][:, :, 1]) + volume_mod3.append(hf['image'][:, :, 2]) + volume_mod4.append(hf['image'][:, :, 3]) + volume_mask1.append(hf['mask'][:, :, 0]) + volume_mask2.append(hf['mask'][:, :, 1]) + volume_mask3.append(hf['mask'][:, :, 2]) + + volume_mod1 = np.stack(volume_mod1) # (d, h, w) + volume_mod2 = np.stack(volume_mod2) # (d, h, w) + volume_mod3 = np.stack(volume_mod3) # (d, h, w) + volume_mod4 = np.stack(volume_mod4) # (d, h, w) + volume_mask1 = np.stack(volume_mask1) # (d, h, w) + volume_mask2 = np.stack(volume_mask2) # (d, h, w) + volume_mask3 = np.stack(volume_mask3) # (d, h, w) + + input_img = nib.Nifti1Image(volume_mod1, np.eye(4)) + os.makedirs(os.path.join(target_base, 'T1', 'pat' + str(patient).zfill(3)), exist_ok=True) + nib.save(input_img, os.path.join(target_base, 'T1', 'pat' + str(patient).zfill(3), 'pat' + str(patient).zfill(3) + '-mod1.nii.gz')) + + input_img = nib.Nifti1Image(volume_mod2, np.eye(4)) + os.makedirs(os.path.join(target_base, 'T1Gd', 'pat' + str(patient).zfill(3)), exist_ok=True) + nib.save(input_img, os.path.join(target_base, 'T1Gd', 'pat' + str(patient).zfill(3), 'pat' + str(patient).zfill(3) + '-mod2.nii.gz')) + + input_img = nib.Nifti1Image(volume_mod3, np.eye(4)) + 
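# mod3 corresponds to the T2 sequence (mod1: T1, mod2: T1Gd, mod4: T2-FLAIR)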
os.makedirs(os.path.join(target_base, 'T2', 'pat' + str(patient).zfill(3)), exist_ok=True) + nib.save(input_img, os.path.join(target_base, 'T2', 'pat' + str(patient).zfill(3), 'pat' + str(patient).zfill(3) + '-mod3.nii.gz')) + + input_img = nib.Nifti1Image(volume_mod4, np.eye(4)) + os.makedirs(os.path.join(target_base, 'T2-FLAIR', 'pat' + str(patient).zfill(3)), exist_ok=True) + nib.save(input_img, os.path.join(target_base, 'T2-FLAIR', 'pat' + str(patient).zfill(3), 'pat' + str(patient).zfill(3) + '-mod4.nii.gz')) + + input_img = nib.Nifti1Image(volume_mask1, np.eye(4)) + nib.save(input_img, os.path.join(target_base, 'T1', 'pat' + str(patient).zfill(3), 'pat' + str(patient).zfill(3) + '-seg-label1.nii.gz')) + nib.save(input_img, os.path.join(target_base, 'T1Gd', 'pat' + str(patient).zfill(3), 'pat' + str(patient).zfill(3) + '-seg-label1.nii.gz')) + nib.save(input_img, os.path.join(target_base, 'T2', 'pat' + str(patient).zfill(3), 'pat' + str(patient).zfill(3) + '-seg-label1.nii.gz')) + nib.save(input_img, os.path.join(target_base, 'T2-FLAIR', 'pat' + str(patient).zfill(3), 'pat' + str(patient).zfill(3) + '-seg-label1.nii.gz')) + + input_img = nib.Nifti1Image(volume_mask2, np.eye(4)) + nib.save(input_img, os.path.join(target_base, 'T1', 'pat' + str(patient).zfill(3), 'pat' + str(patient).zfill(3) + '-seg-label2.nii.gz')) + nib.save(input_img, os.path.join(target_base, 'T1Gd', 'pat' + str(patient).zfill(3), 'pat' + str(patient).zfill(3) + '-seg-label2.nii.gz')) + nib.save(input_img, os.path.join(target_base, 'T2', 'pat' + str(patient).zfill(3), 'pat' + str(patient).zfill(3) + '-seg-label2.nii.gz')) + nib.save(input_img, os.path.join(target_base, 'T2-FLAIR', 'pat' + str(patient).zfill(3), 'pat' + str(patient).zfill(3) + '-seg-label2.nii.gz')) + + input_img = nib.Nifti1Image(volume_mask3, np.eye(4)) + nib.save(input_img, os.path.join(target_base, 'T1', 'pat' + str(patient).zfill(3), 'pat' + str(patient).zfill(3) + '-seg-label3.nii.gz')) + nib.save(input_img, os.path.join(target_base, 'T1Gd', 'pat' + str(patient).zfill(3), 'pat' + str(patient).zfill(3) + '-seg-label3.nii.gz')) + nib.save(input_img, os.path.join(target_base, 'T2', 'pat' + str(patient).zfill(3), 'pat' + str(patient).zfill(3) + '-seg-label3.nii.gz')) + nib.save(input_img, os.path.join(target_base, 'T2-FLAIR', 'pat' + str(patient).zfill(3), 'pat' + str(patient).zfill(3) + '-seg-label3.nii.gz')) + + +class cropper(): + def __init__(self, cfg_path="/home/soroosh/Documents/Repositories/federated_he/config/config.yaml"): + """ + Cropping the all the images and segmentations around the brain + Parameters + ---------- + cfg_path + """ + self.params = read_config(cfg_path) + self.file_base_dir = self.params['file_path'] + org_df = pd.read_csv(os.path.join(self.file_base_dir, "brats20_master_list.csv"), sep=',') + self.df = org_df[org_df['pat_num'] > 72] + + + + def create_nonzero_mask(self, data): + from scipy.ndimage import binary_fill_holes + assert len(data.shape) == 4 or len(data.shape) == 3, "data must have shape (C, X, Y, Z) or shape (C, X, Y)" + nonzero_mask = np.zeros(data.shape[1:], dtype=bool) + for c in range(data.shape[0]): + this_mask = data[c] != 0 + nonzero_mask = nonzero_mask | this_mask + nonzero_mask = binary_fill_holes(nonzero_mask) + return nonzero_mask + + + def get_bbox_from_mask(self, mask, outside_value=0): + mask_voxel_coords = np.where(mask != outside_value) + minzidx = int(np.min(mask_voxel_coords[0])) + maxzidx = int(np.max(mask_voxel_coords[0])) + 1 + minxidx = int(np.min(mask_voxel_coords[1])) + maxxidx = 
int(np.max(mask_voxel_coords[1])) + 1 + minyidx = int(np.min(mask_voxel_coords[2])) + maxyidx = int(np.max(mask_voxel_coords[2])) + 1 + return [[minzidx, maxzidx], [minxidx, maxxidx], [minyidx, maxyidx]] + + + def crop_to_bbox(self, image, bbox): + assert len(image.shape) == 3, "only supports 3d images" + resizer = (slice(bbox[0][0], bbox[0][1]), slice(bbox[1][0], bbox[1][1]), slice(bbox[2][0], bbox[2][1])) + return image[resizer] + + + def perform_cropping(self): + + for index, row in tqdm(self.df.iterrows()): + path_pat = os.path.join(self.file_base_dir, 'pat' + str(row['pat_num']).zfill(3)) + path_file = os.path.join(path_pat, 'pat' + str(row['pat_num']).zfill(3) + '-mod1.nii.gz') + x_input_nifti = nib.load(path_file) + data = x_input_nifti.get_fdata() + data = np.expand_dims(data, 0) + nonzero_mask = self.create_nonzero_mask(data) + bbox = self.get_bbox_from_mask(nonzero_mask, 0) + # bbox = [[0, 148], [41, 190], [35, 220]] + bbox[1:] = [[41, 190], [35, 220]] + cropped_data = [] + for c in range(data.shape[0]): + cropped = self.crop_to_bbox(data[c], bbox) + cropped_data.append(cropped[None]) + data = np.vstack(cropped_data) + x_input_nifti.header['dim'][1:4] = np.array(data[0].shape) + x_input_nifti.affine[0, 3] = bbox[0][0] + 1 + x_input_nifti.affine[1, 3] = bbox[1][0] + 1 + x_input_nifti.affine[2, 3] = bbox[2][0] + 1 + resultt = nib.Nifti1Image(data[0], affine=x_input_nifti.affine, header=x_input_nifti.header) + nib.save(resultt, path_file) + + # mod2 + path_file = path_file.replace('mod1', 'mod2') + x_input_nifti = nib.load(path_file) + data = x_input_nifti.get_fdata() + data = np.expand_dims(data, 0) + cropped_data = [] + for c in range(data.shape[0]): + cropped = self.crop_to_bbox(data[c], bbox) + cropped_data.append(cropped[None]) + data = np.vstack(cropped_data) + x_input_nifti.header['dim'][1:4] = np.array(data[0].shape) + x_input_nifti.affine[0, 3] = bbox[0][0] + 1 + x_input_nifti.affine[1, 3] = bbox[1][0] + 1 + x_input_nifti.affine[2, 3] = bbox[2][0] + 1 + resultt = nib.Nifti1Image(data[0], affine=x_input_nifti.affine, header=x_input_nifti.header) + nib.save(resultt, path_file) + + # mod3 + path_file = path_file.replace('mod2', 'mod3') + x_input_nifti = nib.load(path_file) + data = x_input_nifti.get_fdata() + data = np.expand_dims(data, 0) + cropped_data = [] + for c in range(data.shape[0]): + cropped = self.crop_to_bbox(data[c], bbox) + cropped_data.append(cropped[None]) + data = np.vstack(cropped_data) + x_input_nifti.header['dim'][1:4] = np.array(data[0].shape) + x_input_nifti.affine[0, 3] = bbox[0][0] + 1 + x_input_nifti.affine[1, 3] = bbox[1][0] + 1 + x_input_nifti.affine[2, 3] = bbox[2][0] + 1 + resultt = nib.Nifti1Image(data[0], affine=x_input_nifti.affine, header=x_input_nifti.header) + nib.save(resultt, path_file) + + # mod4 + path_file = path_file.replace('mod3', 'mod4') + x_input_nifti = nib.load(path_file) + data = x_input_nifti.get_fdata() + data = np.expand_dims(data, 0) + cropped_data = [] + for c in range(data.shape[0]): + cropped = self.crop_to_bbox(data[c], bbox) + cropped_data.append(cropped[None]) + data = np.vstack(cropped_data) + x_input_nifti.header['dim'][1:4] = np.array(data[0].shape) + x_input_nifti.affine[0, 3] = bbox[0][0] + 1 + x_input_nifti.affine[1, 3] = bbox[1][0] + 1 + x_input_nifti.affine[2, 3] = bbox[2][0] + 1 + resultt = nib.Nifti1Image(data[0], affine=x_input_nifti.affine, header=x_input_nifti.header) + nib.save(resultt, path_file) + + # seg-label1 + path_file = path_file.replace('mod4', 'seg-label1') + x_input_nifti = nib.load(path_file) + 
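# the bounding box computed from mod1 above is reused so the label stays aligned with the cropped images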
data = x_input_nifti.get_fdata() + data = np.expand_dims(data, 0) + cropped_data = [] + for c in range(data.shape[0]): + cropped = self.crop_to_bbox(data[c], bbox) + cropped_data.append(cropped[None]) + data = np.vstack(cropped_data) + x_input_nifti.header['dim'][1:4] = np.array(data[0].shape) + x_input_nifti.affine[0, 3] = bbox[0][0] + 1 + x_input_nifti.affine[1, 3] = bbox[1][0] + 1 + x_input_nifti.affine[2, 3] = bbox[2][0] + 1 + resultt = nib.Nifti1Image(data[0], affine=x_input_nifti.affine, header=x_input_nifti.header) + nib.save(resultt, path_file) + + # seg-label2 + path_file = path_file.replace('seg-label1', 'seg-label2') + x_input_nifti = nib.load(path_file) + data = x_input_nifti.get_fdata() + data = np.expand_dims(data, 0) + cropped_data = [] + for c in range(data.shape[0]): + cropped = self.crop_to_bbox(data[c], bbox) + cropped_data.append(cropped[None]) + data = np.vstack(cropped_data) + x_input_nifti.header['dim'][1:4] = np.array(data[0].shape) + x_input_nifti.affine[0, 3] = bbox[0][0] + 1 + x_input_nifti.affine[1, 3] = bbox[1][0] + 1 + x_input_nifti.affine[2, 3] = bbox[2][0] + 1 + resultt = nib.Nifti1Image(data[0], affine=x_input_nifti.affine, header=x_input_nifti.header) + nib.save(resultt, path_file) + + # seg-label3 + path_file = path_file.replace('seg-label2', 'seg-label3') + x_input_nifti = nib.load(path_file) + data = x_input_nifti.get_fdata() + data = np.expand_dims(data, 0) + cropped_data = [] + for c in range(data.shape[0]): + cropped = self.crop_to_bbox(data[c], bbox) + cropped_data.append(cropped[None]) + data = np.vstack(cropped_data) + x_input_nifti.header['dim'][1:4] = np.array(data[0].shape) + x_input_nifti.affine[0, 3] = bbox[0][0] + 1 + x_input_nifti.affine[1, 3] = bbox[1][0] + 1 + x_input_nifti.affine[2, 3] = bbox[2][0] + 1 + resultt = nib.Nifti1Image(data[0], affine=x_input_nifti.affine, header=x_input_nifti.header) + nib.save(resultt, path_file) + + + +if __name__ == '__main__': + # handler = csv_preprocess_mimic() + # handler.hd5_to_nifti() + crroppper = cropper() + crroppper = crroppper.perform_cropping() diff --git a/data/data_provider_brats.py b/data/data_provider_brats.py new file mode 100644 index 0000000..6844bb8 --- /dev/null +++ b/data/data_provider_brats.py @@ -0,0 +1,414 @@ +""" +Created on March 4, 2022. +data_provider_brats.py + +@author: Soroosh Tayebi Arasteh +https://github.com/tayebiarasteh/ +""" + +import os +import torch +import pdb +import pandas as pd +import numpy as np +from skimage.io import imsave +from torch.utils.data import Dataset +import nibabel as nib +from scipy.ndimage.interpolation import zoom +from copy import deepcopy + +from config.serde import read_config + + + + + + +class data_loader_3D(Dataset): + """ + This is the pipeline based on Pytorch's Dataset and Dataloader + """ + def __init__(self, cfg_path, mode='train', modality=2, multimodal=True): + """ + Parameters + ---------- + cfg_path: str + Config file path of the experiment + + mode: str + Nature of operation to be done with the data. 
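+            Selects the matching soroosh_split subset of the master list.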
+ Possible inputs are train, valid, test + Default value: train + + modality: int + modality of the MR sequence + 1: T1 + 2: T1Gd + 3: T2 + 4: T2-FLAIR + """ + + self.cfg_path = cfg_path + self.params = read_config(cfg_path) + self.file_base_dir = self.params['file_path'] + self.modality = int(modality) + self.multimodal = multimodal + org_df = pd.read_csv(os.path.join(self.file_base_dir, "brats20_master_list.csv"), sep=',') + + if mode=='train': + self.subset_df = org_df[org_df['soroosh_split'] == 'train'] + elif mode == 'valid': + self.subset_df = org_df[org_df['soroosh_split'] == 'valid'] + elif mode == 'test': + self.subset_df = org_df[org_df['soroosh_split'] == 'test'] + + self.file_path_list = list(self.subset_df['pat_num']) + + + + def __len__(self): + """Returns the length of the dataset""" + return len(self.file_path_list) + + + def __getitem__(self, idx): + """ + Parameters + ---------- + idx: int + + Returns + ------- + img: torch tensor + label: torch tensor + """ + path_pat = os.path.join(self.file_base_dir, 'pat' + str(self.file_path_list[idx]).zfill(3)) + path_file = os.path.join(path_pat, 'pat' + str(self.file_path_list[idx]).zfill(3) + '-mod' + str(self.modality) + '.nii.gz') + img = nib.load(path_file).get_fdata() + img = img.astype(np.float32) # (d, h, w) + + label_path1 = path_file.replace('-mod' + str(self.modality), '-seg-label1') + label1 = nib.load(label_path1).get_fdata() # (d, h, w) + label_path2 = path_file.replace('-mod' + str(self.modality), '-seg-label2') + label2 = nib.load(label_path2).get_fdata() # (d, h, w) + label_path3 = path_file.replace('-mod' + str(self.modality), '-seg-label3') + label3 = nib.load(label_path3).get_fdata() # (d, h, w) + + if self.multimodal: + path_file1 = path_file.replace('-mod' + str(self.modality), '-mod1') + img1 = nib.load(path_file1).get_fdata() # (d, h, w) + # normalization + normalized_img1 = self.irm_min_max_preprocess(img1.transpose(1, 2, 0)) # (h, w, d) + normalized_img1 = normalized_img1.transpose(2, 0, 1) # (d, h, w) + + path_file2 = path_file.replace('-mod' + str(self.modality), '-mod2') + img2 = nib.load(path_file2).get_fdata() # (d, h, w) + # normalization + normalized_img2 = self.irm_min_max_preprocess(img2.transpose(1, 2, 0)) # (h, w, d) + normalized_img2 = normalized_img2.transpose(2, 0, 1) # (d, h, w) + + path_file3 = path_file.replace('-mod' + str(self.modality), '-mod3') + img3 = nib.load(path_file3).get_fdata() # (d, h, w) + # normalization + normalized_img3 = self.irm_min_max_preprocess(img3.transpose(1, 2, 0)) # (h, w, d) + normalized_img3 = normalized_img3.transpose(2, 0, 1) # (d, h, w) + + path_file4 = path_file.replace('-mod' + str(self.modality), '-mod4') + img4 = nib.load(path_file4).get_fdata() # (d, h, w) + # normalization + normalized_img4 = self.irm_min_max_preprocess(img4.transpose(1, 2, 0)) # (h, w, d) + normalized_img4 = normalized_img4.transpose(2, 0, 1) # (d, h, w) + + # image resizing for memory issues + if normalized_img1.size > 2433600: # 100*156*156 = 2433600 + normalized_img_resized1, label1 = self.resize_manual(normalized_img1, label1) + normalized_img_resized2, label2 = self.resize_manual(normalized_img2, label2) + normalized_img_resized3, label3 = self.resize_manual(normalized_img3, label3) + normalized_img_resized4, _ = self.resize_manual(normalized_img4, label3) + + normalized_img_resized = np.stack((normalized_img_resized1, normalized_img_resized2, + normalized_img_resized3, normalized_img_resized4)) # (c=4, d, h, w) + else: + normalized_img_resized = np.stack((normalized_img1, 
+            # image resizing for memory issues
+            if normalized_img1.size > 2433600:  # 100*156*156 = 2433600
+                normalized_img_resized1, label1 = self.resize_manual(normalized_img1, label1)
+                normalized_img_resized2, label2 = self.resize_manual(normalized_img2, label2)
+                normalized_img_resized3, label3 = self.resize_manual(normalized_img3, label3)
+                # label3 has already been resized above; the label output is discarded here
+                normalized_img_resized4, _ = self.resize_manual(normalized_img4, label3)
+
+                normalized_img_resized = np.stack((normalized_img_resized1, normalized_img_resized2,
+                                                   normalized_img_resized3, normalized_img_resized4))  # (c=4, d, h, w)
+            else:
+                normalized_img_resized = np.stack((normalized_img1, normalized_img2,
+                                                   normalized_img3, normalized_img4))  # (c=4, d, h, w)
+            normalized_img_resized = torch.from_numpy(normalized_img_resized)  # (c=4, d, h, w)
+
+        else:
+            # normalization
+            normalized_img = self.irm_min_max_preprocess(img.transpose(1, 2, 0))  # (h, w, d)
+            normalized_img = normalized_img.transpose(2, 0, 1)  # (d, h, w)
+
+            # image resizing for memory issues
+            if normalized_img.size > 2433600:  # 100*156*156 = 2433600
+                normalized_img_resized, label1 = self.resize_manual(normalized_img, label1)
+                _, label2 = self.resize_manual(normalized_img, label2)
+                _, label3 = self.resize_manual(normalized_img, label3)
+            else:
+                normalized_img_resized = normalized_img
+
+            normalized_img_resized = torch.from_numpy(normalized_img_resized)  # (d, h, w)
+            normalized_img_resized = torch.unsqueeze(normalized_img_resized, 0)  # (c=1, d, h, w)
+
+        label1 = torch.from_numpy(label1)  # (d, h, w)
+        label2 = torch.from_numpy(label2)  # (d, h, w)
+        label3 = torch.from_numpy(label3)  # (d, h, w)
+        label = torch.stack((label1, label2, label3))  # (c=3, d, h, w)
+
+        return normalized_img_resized, label
+
+
+    def data_normalization_mean_std(self, image):
+        """Subtract the mean and divide by the std for each individual patient
+        and modality; mean and std are computed over the nonzero (foreground)
+        voxels only.
+
+        Parameters
+        ----------
+        image: numpy array
+            The raw input image
+        Returns
+        -------
+        normalized_img: numpy array
+            The normalized image
+        """
+        mean = image[image > 0].mean()
+        std = image[image > 0].std()
+
+        # 'outzero_normalization' is never set in __init__, so default it to off
+        if getattr(self, 'outzero_normalization', False):
+            image[image < 0] = -1000
+
+        normalized_img = (image - mean) / std
+
+        if getattr(self, 'outzero_normalization', False):
+            normalized_img[normalized_img < -100] = 0
+
+        return normalized_img
+
+
+    def irm_min_max_preprocess(self, image, low_perc=1, high_perc=99):
+        """Main pre-processing function used for the challenge (seems to work the best).
+        Remove outlier voxels first, then min-max scale.
+        Warnings
+        --------
+        This will not do it channel wise!!
+        """
+        non_zeros = image > 0
+        low, high = np.percentile(image[non_zeros], [low_perc, high_perc])
+        image = np.clip(image, low, high)
+
+        min_ = np.min(image)
+        max_ = np.max(image)
+        scale = max_ - min_
+        image = (image - min_) / scale
+
+        return image
+
+
+    def resize_manual(self, img, gt):
+        """Downsampling of the image and its label.
+        Parameters
+        ----------
+        img: numpy array
+            input image
+        gt: numpy array
+            input label
+        Returns
+        -------
+        img: numpy array
+            downsampled image
+        gt: numpy array
+            downsampled label
+        """
+        resize_ratio = np.divide(tuple(self.params['Network']['resize_shape']), img.shape)
+        img = zoom(img, resize_ratio, order=2)
+        gt = zoom(gt, resize_ratio, order=0)
+        return img, gt
+
+
+
+
+class data_loader_without_label_3D():
+    """
+    This is the dataloader based on our own implementation.
+    """
+    def __init__(self, cfg_path, mode='test', modality=2, multimodal=True):
+        """
+        Parameters
+        ----------
+        cfg_path: str
+            Config file path of the experiment
+        mode: str
+            Nature of operation to be done with the data.
+            Possible inputs are train, valid, test
+            Default value: test
+
+        modality: int
+            modality of the MR sequence
+            1: T1
+            2: T1Gd
+            3: T2
+            4: T2-FLAIR
+        """
+
+        self.cfg_path = cfg_path
+        self.params = read_config(cfg_path)
+        self.file_base_dir = self.params['file_path']
+        self.modality = int(modality)
+        self.multimodal = multimodal
+        org_df = pd.read_csv(os.path.join(self.file_base_dir, "brats20_master_list.csv"), sep=',')
+
+        if mode == 'train':
+            self.subset_df = org_df[org_df['soroosh_split'] == 'train']
+        elif mode == 'valid':
+            self.subset_df = org_df[org_df['soroosh_split'] == 'valid']
+        elif mode == 'test':
+            self.subset_df = org_df[org_df['soroosh_split'] == 'test']
+
+        self.file_path_list = list(self.subset_df['pat_num'])
+
+
+    def provide_test_without_label(self, file_path):
+        """Test data provider for prediction.
+
+        Returns
+        -------
+        normalized_img_resized: torch tensor
+            network input, (n=1, c, d, h, w)
+        img_nifti: nibabel image
+            the original nifti object of the chosen modality
+        img_resized: numpy array
+            the (possibly downsampled) raw image
+        scaling_ratio: numpy array
+            voxel-size scaling between the original and the resized image
+        """
+        img_nifti = nib.load(file_path)
+        img = img_nifti.get_fdata()
+        img = img.astype(np.float32)  # (d, h, w)
+
+        if self.multimodal:
+            path_file1 = file_path.replace('-mod' + str(self.modality), '-mod1')
+            img1 = nib.load(path_file1).get_fdata()  # (d, h, w)
+            # normalization
+            normalized_img1 = self.irm_min_max_preprocess(img1.transpose(1, 2, 0))  # (h, w, d)
+            normalized_img1 = normalized_img1.transpose(2, 0, 1)  # (d, h, w)
+
+            path_file2 = file_path.replace('-mod' + str(self.modality), '-mod2')
+            img2 = nib.load(path_file2).get_fdata()  # (d, h, w)
+            # normalization
+            normalized_img2 = self.irm_min_max_preprocess(img2.transpose(1, 2, 0))  # (h, w, d)
+            normalized_img2 = normalized_img2.transpose(2, 0, 1)  # (d, h, w)
+
+            path_file3 = file_path.replace('-mod' + str(self.modality), '-mod3')
+            img3 = nib.load(path_file3).get_fdata()  # (d, h, w)
+            # normalization
+            normalized_img3 = self.irm_min_max_preprocess(img3.transpose(1, 2, 0))  # (h, w, d)
+            normalized_img3 = normalized_img3.transpose(2, 0, 1)  # (d, h, w)
+
+            path_file4 = file_path.replace('-mod' + str(self.modality), '-mod4')
+            img4 = nib.load(path_file4).get_fdata()  # (d, h, w)
+            # normalization
+            normalized_img4 = self.irm_min_max_preprocess(img4.transpose(1, 2, 0))  # (h, w, d)
+            normalized_img4 = normalized_img4.transpose(2, 0, 1)  # (d, h, w)
+
+            # image resizing for memory issues
+            if normalized_img1.size > 2433600:  # 100*156*156 = 2433600
+                normalized_img_resized1, _ = self.resize_manual(normalized_img1, normalized_img1)
+                normalized_img_resized2, _ = self.resize_manual(normalized_img2, normalized_img2)
+                normalized_img_resized3, _ = self.resize_manual(normalized_img3, normalized_img3)
+                normalized_img_resized4, _ = self.resize_manual(normalized_img4, normalized_img4)
+                img_resized, _ = self.resize_manual(img, img)
+
+                normalized_img_resized = np.stack((normalized_img_resized1, normalized_img_resized2,
+                                                   normalized_img_resized3, normalized_img_resized4))  # (c=4, d, h, w)
+            else:
+                normalized_img_resized = np.stack((normalized_img1, normalized_img2,
+                                                   normalized_img3, normalized_img4))  # (c=4, d, h, w)
+                img_resized = img
+
+            normalized_img_resized = torch.from_numpy(normalized_img_resized)  # (c=4, d, h, w)
+
+        else:
+            # normalization
+            normalized_img = self.irm_min_max_preprocess(img.transpose(1, 2, 0))  # (h, w, d)
+            normalized_img = normalized_img.transpose(2, 0, 1)  # (d, h, w)
+
+            # image resizing for memory issues
+            if normalized_img.size > 2433600:  # 100*156*156 = 2433600
+                normalized_img_resized, _ = self.resize_manual(normalized_img, normalized_img)
+                img_resized, _ = self.resize_manual(img, img)
+            else:
+                normalized_img_resized = normalized_img
+                img_resized = img
+
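+            # NOTE: the un-normalized img_resized and the voxel-size-aware
+            # scaling_ratio computed below are returned alongside the network
+            # input so main_predict_3D can write predictions out with a
+            # consistent header (pixdim/affine) at the downsampled resolution.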
+            normalized_img_resized = torch.from_numpy(normalized_img_resized)  # (d, h, w)
+            normalized_img_resized = torch.unsqueeze(normalized_img_resized, 0)  # (c=1, d, h, w)
+
+        scaling_ratio = img.shape * np.array(img_nifti.header.get_zooms()) / img_resized.shape
+        scaling_ratio = scaling_ratio.astype(np.float32)
+
+        normalized_img_resized = normalized_img_resized.unsqueeze(0)  # (n=1, c, d, h, w)
+
+        return normalized_img_resized, img_nifti, img_resized, scaling_ratio
+
+
+    def irm_min_max_preprocess(self, image, low_perc=1, high_perc=99):
+        """Main pre-processing function used for the challenge (seems to work the best).
+        Remove outlier voxels first, then min-max scale.
+        Warnings
+        --------
+        This will not do it channel wise!!
+        """
+        non_zeros = image > 0
+        low, high = np.percentile(image[non_zeros], [low_perc, high_perc])
+        image = np.clip(image, low, high)
+
+        min_ = np.min(image)
+        max_ = np.max(image)
+        scale = max_ - min_
+        image = (image - min_) / scale
+
+        return image
+
+
+    def data_normalization_mean_std(self, image):
+        """Subtract the mean and divide by the std for each individual patient
+        and modality; mean and std are computed over the nonzero (foreground)
+        voxels only.
+        Parameters
+        ----------
+        image: numpy array
+            The raw input image
+        Returns
+        -------
+        normalized_img: numpy array
+            The normalized image
+        """
+        mean = image[image > 0].mean()
+        std = image[image > 0].std()
+
+        # 'outzero_normalization' is never set in __init__, so default it to off
+        if getattr(self, 'outzero_normalization', False):
+            image[image < 0] = -1000
+
+        normalized_img = (image - mean) / std
+
+        if getattr(self, 'outzero_normalization', False):
+            normalized_img[normalized_img < -100] = 0
+
+        return normalized_img
+
+
+    def resize_manual(self, img, gt):
+        """Downsampling of the image and its label.
+        Parameters
+        ----------
+        img: numpy array
+            input image
+        gt: numpy array
+            input label
+        Returns
+        -------
+        img: numpy array
+            downsampled image
+        gt: numpy array
+            downsampled label
+        """
+        resize_ratio = np.divide(tuple(self.params['Network']['resize_shape']), img.shape)
+        img = zoom(img, resize_ratio, order=2)
+        gt = zoom(gt, resize_ratio, order=0)
+        return img, gt
\ No newline at end of file
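A short usage sketch for the 3D loader above; the config path is a placeholder and the batch size is illustrative:

    # minimal sketch: iterate the 3D training pipeline (paths are placeholders)
    import torch
    from data.data_provider_brats import data_loader_3D

    dataset = data_loader_3D(cfg_path='/path/to/config.yaml', mode='train', modality=2, multimodal=True)
    loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True)
    image, label = next(iter(loader))   # image: (N, 4, d, h, w), label: (N, 3, d, h, w)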
diff --git a/data/data_utils.py b/data/data_utils.py
new file mode 100644
index 0000000..48c62e1
--- /dev/null
+++ b/data/data_utils.py
@@ -0,0 +1,147 @@
+"""
+Created on March 7, 2022.
+data_utils.py
+
+@author: Soroosh Tayebi Arasteh
+https://github.com/tayebiarasteh/
+"""
+
+
+import numpy as np
+import os
+import pdb
+import nibabel as nib
+import pandas as pd
+from tqdm import tqdm
+
+import SimpleITK as sitk
+import shutil
+from multiprocessing import Pool
+from collections import OrderedDict
+
+
+
+def weight_creator(file_base_dir="/home/soroosh/Documents/datasets/BraTS20/", label_num=2, modality=2):
+    """Inverse class frequency weight creator based on the training data.
+    Note that all the label files should have the same class numbers,
+    and the numbers should start from 0 and be integers.
+
+    Parameters
+    ----------
+    label_num: int
+        1: tumor core (smallest)
+        2: whole tumor (biggest)
+        3: enhancing tumor
+
+    modality: int
+        modality of the MR sequence
+        1: T1
+        2: T1Gd
+        3: T2
+        4: T2-FLAIR
+
+    Returns
+    -------
+    weight: list
+        a list including the inverted class frequencies based on the training data
+    """
+    chosen_df = pd.read_csv(os.path.join(file_base_dir, "brats20_master_list.csv"), sep=',')
+    chosen_df = chosen_df[chosen_df['soroosh_split'] == 'train']
+
+    if int(modality) == 1:
+        file_base_dir = os.path.join(file_base_dir, 'T1')
+    elif int(modality) == 2:
+        file_base_dir = os.path.join(file_base_dir, 'T1Gd')
+    elif int(modality) == 3:
+        file_base_dir = os.path.join(file_base_dir, 'T2')
+    elif int(modality) == 4:
+        file_base_dir = os.path.join(file_base_dir, 'T2-FLAIR')
+
+    file_list = list(chosen_df['pat_num'])
+    path_pat = os.path.join(file_base_dir, 'pat' + str(int(file_list[0])).zfill(3))
+    label_path = os.path.join(path_pat, 'pat' + str(int(file_list[0])).zfill(3) + '-seg-label' + str(int(label_num)) + '.nii.gz')
+    sums_array = np.zeros_like(np.unique(nib.load(label_path).get_fdata()))
+
+    for idx in range(len(file_list)):
+
+        path_pat = os.path.join(file_base_dir, 'pat' + str(int(file_list[idx])).zfill(3))
+        label_path = os.path.join(path_pat, 'pat' + str(int(file_list[idx])).zfill(3) + '-seg-label' + str(int(label_num)) + '.nii.gz')
+        label = nib.load(label_path).get_fdata()
+
+        for classnum in range(len(np.unique(label))):
+            sums_array[classnum] += (label == classnum).sum()
+
+    total = sums_array.sum()
+
+    tempweight = total / sums_array
+    final_factor = sum(tempweight)
+    weight = tempweight / final_factor
+    print(weight)
+
+    return weight
+
+
+
+
+def mean_std_calculator(file_base_dir="/home/soroosh/Documents/datasets/BraTS20/", modality=2):
+    """Mean and std calculator over the nonzero voxels of the training images.
+
+    Parameters
+    ----------
+    modality: int
+        modality of the MR sequence
+        1: T1
+        2: T1Gd
+        3: T2
+        4: T2-FLAIR
+
+    Returns
+    -------
+    final_mean: float
+        mean over all nonzero voxels of the training set
+    final_std: float
+        std over all nonzero voxels of the training set
+    """
+    chosen_df = pd.read_csv(os.path.join(file_base_dir, "brats20_master_list.csv"), sep=',')
+    chosen_df = chosen_df[chosen_df['soroosh_split'] == 'train']
+
+    if int(modality) == 1:
+        file_base_dir = os.path.join(file_base_dir, 'T1')
+    elif int(modality) == 2:
+        file_base_dir = os.path.join(file_base_dir, 'T1Gd')
+    elif int(modality) == 3:
+        file_base_dir = os.path.join(file_base_dir, 'T2')
+    elif int(modality) == 4:
+        file_base_dir = os.path.join(file_base_dir, 'T2-FLAIR')
+
+    file_list = list(chosen_df['pat_num'])
+    stackk = np.array([])
+
+    for idx in tqdm(range(len(file_list))):
+
+        path_pat = os.path.join(file_base_dir, 'pat' + str(int(file_list[idx])).zfill(3))
+        img_path = os.path.join(path_pat, 'pat' + str(int(file_list[idx])).zfill(3) + '-mod' + str(modality) + '.nii.gz')
+        img = nib.load(img_path).get_fdata()
+
+        stackk = np.hstack((stackk, img[img > 0]))
+
+    final_mean = stackk.mean()
+    final_std = stackk.std()
+    print(final_mean, final_std)
+
+    return final_mean, final_std
+
+
+
+
+if __name__ == '__main__':
+    weight = weight_creator(file_base_dir="/home/soroosh/Documents/datasets/BraTS20/", label_num=2, modality=2)
+    # final_mean, final_std = mean_std_calculator(file_base_dir="/home/soroosh/Documents/datasets/BraTS20/", modality=2)
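To make the inverse-frequency weighting above concrete, a toy two-class example with invented voxel counts:

    import numpy as np

    # toy binary case: 90% background voxels, 10% tumor voxels
    sums_array = np.array([900., 100.])
    total = sums_array.sum()                 # 1000
    tempweight = total / sums_array          # [1.11..., 10.0]
    weight = tempweight / tempweight.sum()   # [0.1, 0.9] -> the rare class gets ~9x the weight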
diff --git a/main_3D_brats.py b/main_3D_brats.py
new file mode 100644
index 0000000..f4e76f8
--- /dev/null
+++ b/main_3D_brats.py
@@ -0,0 +1,260 @@
+"""
+Created on March 4, 2022.
+main_3D_brats.py
+
+@author: Soroosh Tayebi Arasteh
+https://github.com/tayebiarasteh/
+"""
+
+import pdb
+import torch
+import os
+from torch.utils.data import Dataset
+from torch.nn import CrossEntropyLoss
+from torchvision import models
+import numpy as np
+from tqdm import tqdm
+import nibabel as nib
+
+from config.serde import open_experiment, create_experiment, delete_experiment, write_config
+from Train_Valid_brats import Training
+from Prediction_brats import Prediction
+from data.data_provider_brats import data_loader_3D, data_loader_without_label_3D
+from models.UNet3D import UNet3D
+from models.EDiceLoss_loss import EDiceLoss
+
+import warnings
+warnings.filterwarnings('ignore')
+
+
+
+def main_train_3D(global_config_path="/home/soroosh/Documents/Repositories/federated_he/config/config.yaml", valid=False,
+                  resume=False, augment=False, experiment_name='name', modality=2):
+    """Main function for training + validation, directly on the 3D volumes.
+
+    Parameters
+    ----------
+    global_config_path: str
+        always global_config_path="/home/soroosh/Documents/Repositories/federated_he/config/config.yaml"
+
+    valid: bool
+        if we want to do validation
+
+    resume: bool
+        if we are resuming training on a model
+
+    augment: bool
+        if we want to have data augmentation during training
+
+    experiment_name: str
+        name of the experiment, in case of resuming training.
+        name of new experiment, in case of new training.
+
+    modality: int
+        modality of the MR sequence
+        1: T1
+        2: T1Gd
+        3: T2
+        4: T2-FLAIR
+    """
+    if resume:
+        params = open_experiment(experiment_name, global_config_path)
+    else:
+        params = create_experiment(experiment_name, global_config_path)
+    cfg_path = params["cfg_path"]
+
+    # Changeable network parameters
+    model = UNet3D(n_out_classes=3)  # for multi label
+
+    loss_function = EDiceLoss  # for multi label
+    optimizer = torch.optim.Adam(model.parameters(), lr=float(params['Network']['lr']),
+                                 weight_decay=float(params['Network']['weight_decay']), amsgrad=params['Network']['amsgrad'])
+
+    train_dataset = data_loader_3D(cfg_path=cfg_path, mode='train', modality=modality)
+    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=params['Network']['batch_size'],
+                                               pin_memory=True, drop_last=True, shuffle=True, num_workers=10)
+    if valid:
+        valid_dataset = data_loader_3D(cfg_path=cfg_path, mode='valid', modality=modality)
+        valid_loader = torch.utils.data.DataLoader(dataset=valid_dataset, batch_size=params['Network']['batch_size'],
+                                                   pin_memory=True, drop_last=True, shuffle=False, num_workers=5)
+    else:
+        valid_loader = None
+
+    trainer = Training(cfg_path, num_epochs=params['num_epochs'], resume=resume, augment=augment)
+    if resume:
+        trainer.load_checkpoint(model=model, optimiser=optimizer, loss_function=loss_function, weight=None)
+    else:
+        trainer.setup_model(model=model, optimiser=optimizer,
+                            loss_function=loss_function, weight=None)
+    trainer.train_epoch(train_loader=train_loader, valid_loader=valid_loader)
+
+
+
+
+def main_test_federated_3D(global_config_path="/home/soroosh/Documents/Repositories/federated_he/config/config.yaml",
+                           experiment_name='main_federated', modality=2):
+    """Main function for multi label prediction.
+
+    Parameters
+    ----------
+    experiment_name: str
+        name of the experiment to be loaded.
+ """ + params = open_experiment(experiment_name, global_config_path) + cfg_path = params['cfg_path'] + model = UNet3D(n_out_classes=2) + + # Initialize prediction + predictor = Prediction(cfg_path) + predictor.setup_model_federated(model=model) + + # Generate test set + test_dataset = data_loader_3D(cfg_path=cfg_path, mode='test') + test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=params['Network']['batch_size'], + pin_memory=True, drop_last=True, shuffle=False, num_workers=5) + + + test_F1, test_specifity, test_sensitivity, test_precision, test_AUCROC = predictor.evaluate_3D(test_loader) + + ### evaluation metrics + print(f'\n\t F1 (Dice score): {test_F1 * 100:.2f}%' + f' | AUC ROC: {test_AUCROC * 100:.2f}% | specifity: {test_specifity * 100:.2f}%' + f' | recall (sensitivity): {test_sensitivity * 100:.2f}% | precision: {test_precision * 100:.2f}%\n') + print('------------------------------------------------------' + '----------------------------------') + + # saving the training and validation stats + msg = f'----------------------------------------------------------------------------------------\n' \ + f'\nF1 (Dice score): {test_F1 * 100:.2f}% ' \ + f' | AUC ROC: {test_AUCROC * 100:.2f}% | specifity: {test_specifity * 100:.2f}%' \ + f' | recall (sensitivity): {test_sensitivity * 100:.2f}% | precision: {test_precision * 100:.2f}%\n\n' + + with open(os.path.join(params['target_dir'], params['stat_log_path']) + '/federatedtest_results', 'a') as f: + f.write(msg) + + + +def main_evaluate_3D(global_config_path="/home/soroosh/Documents/Repositories/federated_he/config/config.yaml", + experiment_name='name', modality=2, tta=False): + """Evaluation (for local models) for all the images using the labels and calculating metrics. + + Parameters + ---------- + experiment_name: str + name of the experiment to be loaded. 
+ """ + params = open_experiment(experiment_name, global_config_path) + cfg_path = params['cfg_path'] + model = UNet3D(n_out_classes=3) + + # Initialize prediction + predictor = Prediction(cfg_path) + predictor.setup_model(model=model) + + # Generate test set + valid_dataset = data_loader_3D(cfg_path=cfg_path, mode='test') + valid_loader = torch.utils.data.DataLoader(dataset=valid_dataset, batch_size=params['Network']['batch_size'], + pin_memory=True, drop_last=True, shuffle=False, num_workers=5) + + if tta: + test_F1, test_accuracy, test_specifity, test_sensitivity, test_precision = predictor.evaluate_3D_tta(valid_loader) + else: + test_F1, test_accuracy, test_specifity, test_sensitivity, test_precision = predictor.evaluate_3D(valid_loader) + + ### evaluation metrics + print(f'\n\t F1 (Dice score): {test_F1.mean().item() * 100:.2f}% | accuracy: {test_accuracy.mean().item() * 100:.2f}%' + f' | specifity: {test_specifity.mean().item() * 100:.2f}%' + f' | recall (sensitivity): {test_sensitivity.mean().item() * 100:.2f}% | precision: {test_precision.mean().item() * 100:.2f}%\n') + + print('Individual F1 scores:') + print(f'Class 1: {test_F1[0].item() * 100:.2f}%') + print(f'Class 2: {test_F1[1].item() * 100:.2f}%') + print(f'Class 3: {test_F1[2].item() * 100:.2f}%\n') + print('------------------------------------------------------' + '----------------------------------') + + # saving the training and validation stats + msg = f'----------------------------------------------------------------------------------------\n' \ + f'\nF1 (Dice score): {test_F1.mean().item() * 100:.2f}% | accuracy: {test_accuracy.mean().item() * 100:.2f}% ' \ + f' | specifity: {test_specifity.mean().item() * 100:.2f}%' \ + f' | recall (sensitivity): {test_sensitivity.mean().item() * 100:.2f}% | precision: {test_precision.mean().item() * 100:.2f}%\n\n' \ + f' | F1 class 1: {test_F1[0].item() * 100:.2f}% | F1 class 2: {test_F1[1].item() * 100:.2f}% | F1 class 3: {test_F1[2].item() * 100:.2f}%\n\n' + + with open(os.path.join(params['target_dir'], params['stat_log_path']) + '/test_results', 'a') as f: + f.write(msg) + + + + + +def main_predict_3D(global_config_path="/home/soroosh/Documents/Repositories/federated_he/config/config.yaml", + experiment_name='name', modality=2, tta=False): + """Prediction without evaluation for all the images. + + Parameters + ---------- + experiment_name: str + name of the experiment to be loaded. 
+ """ + params = open_experiment(experiment_name, global_config_path) + cfg_path = params['cfg_path'] + model = UNet3D(n_out_classes=3) + + # Initialize prediction + predictor = Prediction(cfg_path) + predictor.setup_model(model=model) + + # Generate test set + test_dataset = data_loader_without_label_3D(cfg_path=cfg_path, mode='test') + + for idx in tqdm(range(len(test_dataset.file_path_list))): + path_pat = os.path.join(test_dataset.file_base_dir, 'pat' + str(test_dataset.file_path_list[idx]).zfill(3)) + path_file = os.path.join(path_pat, 'pat' + str(test_dataset.file_path_list[idx]).zfill(3) + '-mod' + str( + test_dataset.modality) + '.nii.gz') + + x_input, x_input_nifti, img_resized, scaling_ratio = test_dataset.provide_test_without_label(file_path=path_file) + + if tta: + output_sigmoided = predictor.predict_3D_tta(x_input) # (d,h,w) + else: + output_sigmoided = predictor.predict_3D(x_input) # (d,h,w) + output_sigmoided = output_sigmoided.cpu().detach().numpy() + + x_input_nifti.header['pixdim'][1:4] = scaling_ratio + x_input_nifti.header['dim'][1:4] = np.array(img_resized.shape) + x_input_nifti.affine[0, 0] = scaling_ratio[0] + x_input_nifti.affine[1, 1] = scaling_ratio[1] + x_input_nifti.affine[2, 2] = scaling_ratio[2] + + # segmentation = nib.Nifti1Image(output_sigmoided[0,0], affine=x_input_nifti.affine, header=x_input_nifti.header) + # nib.save(segmentation, os.path.join(params['target_dir'], params['output_data_path'], os.path.basename(path_file).replace('.nii.gz', '-downsampled' + str(test_dataset.label_num) + '-label' + '.nii.gz'))) + segmentation = nib.Nifti1Image(output_sigmoided[0,0], affine=x_input_nifti.affine, header=x_input_nifti.header) + nib.save(segmentation, os.path.join(params['target_dir'], params['output_data_path'], os.path.basename(path_file).replace('.nii.gz', '-downsampled1-label.nii.gz'))) + segmentation = nib.Nifti1Image(output_sigmoided[0,1], affine=x_input_nifti.affine, header=x_input_nifti.header) + nib.save(segmentation, os.path.join(params['target_dir'], params['output_data_path'], os.path.basename(path_file).replace('.nii.gz', '-downsampled2-label.nii.gz'))) + segmentation = nib.Nifti1Image(output_sigmoided[0,2], affine=x_input_nifti.affine, header=x_input_nifti.header) + nib.save(segmentation, os.path.join(params['target_dir'], params['output_data_path'], os.path.basename(path_file).replace('.nii.gz', '-downsampled3-label.nii.gz'))) + input_img = nib.Nifti1Image(img_resized, affine=x_input_nifti.affine, header=x_input_nifti.header) + nib.save(input_img, os.path.join(params['target_dir'], params['output_data_path'], + os.path.basename(path_file).replace('.nii.gz', '-downsampled-image.nii.gz'))) + pdb.set_trace() + + + + + + + + + + + + +if __name__ == '__main__': + # delete_experiment(experiment_name='4levelunet24_flip_gamma_AWGN_zoomin_central_full_lr1e4_80_80_80', global_config_path="/home/soroosh/Documents/Repositories/federated_he/config/config.yaml") + main_train_3D(global_config_path="/home/soroosh/Documents/Repositories/federated_he/config/config.yaml", + valid=True, resume=False, augment=True, experiment_name='4levelunet24_flip_gamma_AWGN_zoomin_central_full_lr1e4_80_80_80') + # main_evaluate_3D(global_config_path="/home/soroosh/Documents/Repositories/federated_he/config/config.yaml", + # experiment_name='multimod_24_4level_flip_gamma_AWGN_blur_cropped_minmax_central_full_lr1e4_80_80_100', tta=False) + # main_predict_3D(global_config_path="/home/soroosh/Documents/Repositories/federated_he/config/config.yaml", + # 
diff --git a/models/Diceloss.py b/models/Diceloss.py
new file mode 100644
index 0000000..c61e6c1
--- /dev/null
+++ b/models/Diceloss.py
@@ -0,0 +1,88 @@
+"""
+Created on March 11, 2022.
+Diceloss.py
+
+@ref: https://github.com/hubutui/DiceLoss-PyTorch
+
+Dice loss (both multi class (and label) & binary).
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+
+class BinaryDiceLoss(nn.Module):
+    """Dice loss of binary class
+    Args:
+        smooth: A float number to smooth loss, and avoid NaN error, default: 1
+        p: Denominator value: \sum{x^p} + \sum{y^p}, default: 2
+        predict: A tensor of shape [N, *]
+        target: A tensor of the same shape as predict
+        reduction: Reduction method to apply, return mean over batch if 'mean',
+            return sum if 'sum', return a tensor of shape [N,] if 'none'
+    Returns:
+        Loss tensor according to arg reduction
+    Raises:
+        Exception if unexpected reduction
+    """
+    def __init__(self, smooth=1, p=2, reduction='mean'):
+        super(BinaryDiceLoss, self).__init__()
+        self.smooth = smooth
+        self.p = p
+        self.reduction = reduction
+
+    def forward(self, predict, target):
+        assert predict.shape[0] == target.shape[0], "predict & target batch size don't match"
+        predict = predict.contiguous().view(predict.shape[0], -1)
+        target = target.contiguous().view(target.shape[0], -1)
+
+        # standard Dice numerator: 2 * |X ∩ Y| (plus smoothing)
+        num = 2 * torch.sum(torch.mul(predict, target), dim=1) + self.smooth
+        den = torch.sum(predict.pow(self.p) + target.pow(self.p), dim=1) + self.smooth
+
+        loss = 1 - num / den
+
+        if self.reduction == 'mean':
+            return loss.mean()
+        elif self.reduction == 'sum':
+            return loss.sum()
+        elif self.reduction == 'none':
+            return loss
+        else:
+            raise Exception('Unexpected reduction {}'.format(self.reduction))
+
+
+class DiceLoss(nn.Module):
+    """Dice loss, needs one hot encoded input
+    Args:
+        weight: An array of shape [num_classes,]
+        ignore_index: class index to ignore
+        predict: A tensor of shape [N, C, *]
+        target: A tensor of the same shape as predict
+        other args pass to BinaryDiceLoss
+    Return:
+        same as BinaryDiceLoss
+    """
+    def __init__(self, weight=None, ignore_index=None, **kwargs):
+        super(DiceLoss, self).__init__()
+        self.kwargs = kwargs
+        self.weight = weight
+        self.ignore_index = ignore_index
+
+    def forward(self, predict, target):
+        assert predict.shape == target.shape, 'predict & target shape do not match'
+        dice = BinaryDiceLoss(**self.kwargs)
+        total_loss = 0
+        predict = F.softmax(predict, dim=1)
+
+        for i in range(target.shape[1]):
+            if i != self.ignore_index:
+                dice_loss = dice(predict[:, i], target[:, i])
+                if self.weight is not None:
+                    assert self.weight.shape[0] == target.shape[1], \
+                        'Expect weight shape [{}], get [{}]'.format(target.shape[1], self.weight.shape[0])
+                    dice_loss *= self.weight[i]
+                total_loss += dice_loss
+
+        return total_loss/target.shape[1]
\ No newline at end of file
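A worked check of BinaryDiceLoss above (smooth=1, p=2) on a tiny invented prediction/target pair:

    import torch
    from models.Diceloss import BinaryDiceLoss

    predict = torch.tensor([[0.8, 0.2, 0.6, 0.1]])
    target = torch.tensor([[1.0, 0.0, 1.0, 0.0]])
    loss = BinaryDiceLoss()(predict, target)
    # num = 2*(0.8 + 0.6) + 1 = 3.8 ; den = (0.64+0.04+0.36+0.01) + (1+1) + 1 = 4.05
    # loss = 1 - 3.8/4.05 ≈ 0.0617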
+ """ + + def __init__(self, do_sigmoid=True): + super(EDiceLoss, self).__init__() + self.do_sigmoid = do_sigmoid + self.labels = ["ET", "TC", "WT"] + self.setup_cuda() + + def binary_dice(self, inputs, targets, label_index, metric_mode=False): + smooth = 1. + if self.do_sigmoid: + inputs = torch.sigmoid(inputs) + + if metric_mode: + inputs = inputs > 0.5 + if targets.sum() == 0: + print(f"No {self.labels[label_index]} for this patient") + if inputs.sum() == 0: + return torch.tensor(1., device=self.device) + else: + return torch.tensor(0., device=self.device) + # Threshold the pred + intersection = EDiceLoss.compute_intersection(inputs, targets) + if metric_mode: + dice = (2 * intersection) / ((inputs.sum() + targets.sum()) * 1.0) + else: + dice = (2 * intersection + smooth) / (inputs.pow(2).sum() + targets.pow(2).sum() + smooth) + if metric_mode: + return dice + return 1 - dice + + @staticmethod + def compute_intersection(inputs, targets): + intersection = torch.sum(inputs * targets) + return intersection + + def forward(self, inputs, target): + dice = 0 + for i in range(target.size(1)): + dice = dice + self.binary_dice(inputs[:, i, ...], target[:, i, ...], i) + + final_dice = dice / target.size(1) + return final_dice + + def metric(self, inputs, target): + dices = [] + for j in range(target.size(0)): + dice = [] + for i in range(target.size(1)): + dice.append(self.binary_dice(inputs[j, i], target[j, i], i, True)) + dices.append(dice) + return dices + + + def setup_cuda(self, cuda_device_id=0): + """setup the device. + + Parameters + ---------- + cuda_device_id: int + cuda device id + """ + if torch.cuda.is_available(): + torch.backends.cudnn.fastest = True + torch.cuda.set_device(cuda_device_id) + self.device = torch.device('cuda') + else: + self.device = torch.device('cpu') diff --git a/models/UNet3D.py b/models/UNet3D.py new file mode 100644 index 0000000..aa3ee88 --- /dev/null +++ b/models/UNet3D.py @@ -0,0 +1,229 @@ +""" +Created on March 4, 2022. 
diff --git a/models/UNet3D.py b/models/UNet3D.py
new file mode 100644
index 0000000..aa3ee88
--- /dev/null
+++ b/models/UNet3D.py
@@ -0,0 +1,229 @@
+"""
+Created on March 4, 2022.
+UNet3D.py
+
+@author: Soroosh Tayebi Arasteh
+https://github.com/tayebiarasteh/
+"""
+
+
+import torch
+import torch.nn as nn
+import pdb
+import torch.nn.functional as F
+
+
+class UNet3D(nn.Module):
+    def __init__(self, n_in_channels=4, n_out_classes=3, threelevel=False, firstdim=24, weight_init=True):
+        """
+        Parameters
+        ----------
+        firstdim: int
+            number of feature maps of the first level, e.g. 16 or 24
+
+        threelevel: bool
+            if we want to have a 3-level UNet or a 4-level one
+
+        weight_init: bool
+            if we want to initialize the biases with zero and the weights with He initialization
+        """
+        super(UNet3D, self).__init__()
+        self.threelevel = threelevel
+
+        self.input_block = inconv(n_in_channels, firstdim, weight_init)
+
+        # 3-level
+        if self.threelevel:
+            self.down1 = down_one(firstdim, firstdim * 2, weight_init)
+            self.down2 = down(firstdim * 2, firstdim * 2, weight_init)
+            self.up2 = up(firstdim * 4, firstdim, weight_init)
+            self.up3 = up_one(firstdim * 2, firstdim, weight_init)
+
+        # 4-level
+        else:
+            self.down1 = down_one(firstdim, firstdim * 2, weight_init)
+            self.down2 = down(firstdim * 2, firstdim * 4, weight_init)
+            self.down3 = down(firstdim * 4, firstdim * 4, weight_init)
+            self.up1 = up(firstdim * 8, firstdim * 2, weight_init)
+            self.up2 = up(firstdim * 4, firstdim, weight_init)
+            self.up3 = up_one(firstdim * 2, firstdim, weight_init)
+
+        self.output_block = outconv(firstdim, n_out_classes, weight_init)
+
+
+    def forward(self, input_tensor):
+        first_output = self.input_block(input_tensor)
+
+        # 3-level
+        if self.threelevel:
+            down1_output = self.down1(first_output)
+            down2_output = self.down2(down1_output)
+            up2_output = self.up2(down2_output, down1_output)
+            up3_output = self.up3(up2_output, first_output)
+
+        # 4-level
+        else:
+            down1_output = self.down1(first_output)
+            down2_output = self.down2(down1_output)
+            down3_output = self.down3(down2_output)
+            up1_output = self.up1(down3_output, down2_output)
+            up2_output = self.up2(up1_output, down1_output)
+            up3_output = self.up3(up2_output, first_output)
+
+        # unpadding
+        if input_tensor.shape[-1] != up3_output.shape[-1]:
+            diff = up3_output.shape[-1] - input_tensor.shape[-1]
+            up3_output = up3_output[..., :-diff]
+        if input_tensor.shape[-2] != up3_output.shape[-2]:
+            diff2 = up3_output.shape[-2] - input_tensor.shape[-2]
+            up3_output = up3_output[:, :, :, :-diff2]
+        if input_tensor.shape[-3] != up3_output.shape[-3]:
+            diff3 = up3_output.shape[-3] - input_tensor.shape[-3]
+            up3_output = up3_output[:, :, :-diff3]
+
+        output_tensor = self.output_block(up3_output)
+
+        return output_tensor
+
+
+
+class double_conv(nn.Module):
+    def __init__(self, in_ch, out_ch, weight_init):
+        super(double_conv, self).__init__()
+        self.conv = nn.Sequential(
+            nn.Conv3d(in_ch, out_ch, kernel_size=3, padding=1),  # (n,c,d,h,w)
+            nn.ReLU(inplace=False),
+            nn.BatchNorm3d(out_ch),
+            nn.Conv3d(out_ch, out_ch, kernel_size=3, padding=1),
+            nn.ReLU(inplace=False),
+            nn.BatchNorm3d(out_ch))
+
+        if weight_init:
+            for idx in range(len(self.conv)):
+                for name, param in self.conv[idx].named_parameters():
+                    if 'bias' in name:
+                        nn.init.constant_(param, 0.0)
+                    elif 'weight' in name:
+                        if isinstance(self.conv[idx], nn.Conv3d) or isinstance(self.conv[idx], nn.Conv2d) or isinstance(
+                                self.conv[idx], nn.ConvTranspose2d) or isinstance(self.conv[idx], nn.ConvTranspose3d):
+                            nn.init.kaiming_normal_(param, a=1e-2)
+
+
+    def forward(self, input_tensor):
+        output_tensor = self.conv(input_tensor)
+        return output_tensor
+
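+# NOTE: kaiming_normal_ above is called with a=1e-2, the He gain for a LeakyReLU
+# with negative slope 0.01, even though the blocks use plain ReLU; since
+# gain = sqrt(2 / (1 + a^2)), this is numerically almost identical to the
+# pure-ReLU default and reads like a carry-over from a LeakyReLU variant.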
+
+# Input Block
+class inconv(nn.Module):
+    def __init__(self, in_ch, out_ch, weight_init):
+        super(inconv, self).__init__()
+        self.in_double_conv1 = double_conv(in_ch, out_ch, weight_init)
+
+    def forward(self, input_tensor):
+        # apply the double convolution to input_tensor and return the result
+        output_tensor = self.in_double_conv1(input_tensor)
+        return output_tensor
+
+
+# Down 2, 3 Block
+class down(nn.Module):
+    def __init__(self, in_ch, out_ch, weight_init):
+        super(down, self).__init__()
+        self.maxpool1 = nn.MaxPool3d(kernel_size=2)
+        self.down_double_conv1 = double_conv(in_ch, out_ch, weight_init)
+
+    def forward(self, input_tensor):
+        # max-pool, then apply the double convolution
+        input_tensor = self.maxpool1(input_tensor)
+        output_tensor = self.down_double_conv1(input_tensor)
+        return output_tensor
+
+
+# Down 1 Block (identical to `down`; kept separate for naming symmetry)
+class down_one(nn.Module):
+    def __init__(self, in_ch, out_ch, weight_init):
+        super(down_one, self).__init__()
+        self.maxpool1 = nn.MaxPool3d(kernel_size=2)
+        self.down_double_conv1 = double_conv(in_ch, out_ch, weight_init)
+
+    def forward(self, input_tensor):
+        # max-pool, then apply the double convolution
+        input_tensor = self.maxpool1(input_tensor)
+        output_tensor = self.down_double_conv1(input_tensor)
+        return output_tensor
+
+
+# Up 2, 3 Block
+class up(nn.Module):
+    def __init__(self, in_ch, out_ch, weight_init):
+        super(up, self).__init__()
+        self.upsample1 = nn.Upsample(scale_factor=2)
+        self.up_double_conv1 = double_conv(in_ch, out_ch, weight_init)
+
+    def forward(self, input_tensor_1, input_tensor_2):
+        input_tensor_1 = self.upsample1(input_tensor_1)
+
+        # zero-padding
+        if input_tensor_1.shape[-1] != input_tensor_2.shape[-1]:
+            diff = input_tensor_2.shape[-1] - input_tensor_1.shape[-1]
+            input_tensor_1 = F.pad(input_tensor_1, (0, diff), "constant", 0)
+        if input_tensor_1.shape[-2] != input_tensor_2.shape[-2]:
+            diff2 = input_tensor_2.shape[-2] - input_tensor_1.shape[-2]
+            input_tensor_1 = F.pad(input_tensor_1, (0, 0, 0, diff2), "constant", 0)
+        if input_tensor_1.shape[-3] != input_tensor_2.shape[-3]:
+            diff3 = input_tensor_2.shape[-3] - input_tensor_1.shape[-3]
+            input_tensor_1 = F.pad(input_tensor_1, (0, 0, 0, 0, 0, diff3), "constant", 0)
+
+        # concatenation of the upsampled result and input_tensor_2
+        input_tensor = torch.cat((input_tensor_1, input_tensor_2), 1)
+        output_tensor = self.up_double_conv1(input_tensor)
+        return output_tensor
+
+
+# Up 1 Block (identical to `up`; kept separate for naming symmetry)
+class up_one(nn.Module):
+    def __init__(self, in_ch, out_ch, weight_init):
+        super(up_one, self).__init__()
+        self.upsample1 = nn.Upsample(scale_factor=2)
+        self.up_double_conv1 = double_conv(in_ch, out_ch, weight_init)
+
+    def forward(self, input_tensor_1, input_tensor_2):
+        input_tensor_1 = self.upsample1(input_tensor_1)
+
+        # zero-padding
+        if input_tensor_1.shape[-1] != input_tensor_2.shape[-1]:
+            diff = input_tensor_2.shape[-1] - input_tensor_1.shape[-1]
+            input_tensor_1 = F.pad(input_tensor_1, (0, diff), "constant", 0)
+        if input_tensor_1.shape[-2] != input_tensor_2.shape[-2]:
+            diff2 = input_tensor_2.shape[-2] - input_tensor_1.shape[-2]
+            input_tensor_1 = F.pad(input_tensor_1, (0, 0, 0, diff2), "constant", 0)
+        if input_tensor_1.shape[-3] != input_tensor_2.shape[-3]:
+            diff3 = input_tensor_2.shape[-3] - input_tensor_1.shape[-3]
+            input_tensor_1 = F.pad(input_tensor_1, (0, 0, 0, 0, 0, diff3), "constant", 0)
+
+        # concatenation of the upsampled result and input_tensor_2
+        input_tensor = torch.cat((input_tensor_1, input_tensor_2), 1)
+        output_tensor = self.up_double_conv1(input_tensor)
+        return output_tensor
+
+
+# Out Block
+class outconv(nn.Module):
+    def __init__(self, in_ch, out_ch, weight_init):
+        super(outconv, self).__init__()
+        self.conv_out = nn.Conv3d(in_ch, out_ch, kernel_size=1)
+        if weight_init:
+            for name, param in self.conv_out.named_parameters():
+                if 'bias' in name:
+                    nn.init.constant_(param, 0.0)
+                elif 'weight' in name:
+                    if isinstance(self.conv_out, nn.Conv3d) or isinstance(self.conv_out, nn.Conv2d) or isinstance(
+                            self.conv_out, nn.ConvTranspose2d) or isinstance(self.conv_out, nn.ConvTranspose3d):
+                        nn.init.kaiming_normal_(param, a=1e-2)
+
+    def forward(self, input_tensor):
+        output_tensor = self.conv_out(input_tensor)
+        return output_tensor
\ No newline at end of file
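Finally, a shape smoke test for the default 4-level UNet3D; the 80^3 input size matches the crop size referenced in the experiment names in main_3D_brats.py:

    import torch
    from models.UNet3D import UNet3D

    net = UNet3D(n_in_channels=4, n_out_classes=3)
    x = torch.randn(1, 4, 80, 80, 80)
    y = net(x)
    print(y.shape)   # torch.Size([1, 3, 80, 80, 80]) -- logits; apply sigmoid downstream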