-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 444bb20
Showing
44 changed files
with
2,938 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
import argparse | ||
import os | ||
|
||
import torch.nn as nn | ||
from torch import optim | ||
from torch.utils.data import DataLoader | ||
|
||
from components import networks3D | ||
from utils.Diag_pretraining import train_data | ||
from utils.UnpairedDataset import UnpairedDataset | ||
from utils.utils import mkdir | ||
|
||
if __name__ == '__main__': | ||
# args definition | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument('--gpu_ids', default='7', help='gpu ids: e.g. 0') | ||
parser.add_argument('--workers', default=4, type=int, help='number of data loading workers') | ||
parser.add_argument('--batch_size', type=int, default=4, help='input batch size') | ||
parser.add_argument('--lr', type=float, default=0.01, help='initial learning rate for adam') | ||
parser.add_argument('--n_epochs', type=int, default=100, help='number of epochs') | ||
parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam') | ||
parser.add_argument('--gamma', type=float, default=0.9, help='basic gamma value for exponentialLR') | ||
parser.add_argument('--init_type', type=str, default='normal', | ||
help='network initialization [normal|xavier|kaiming|orthogonal]') | ||
parser.add_argument('--init_gain', type=float, default=0.02, | ||
help='scaling factor for normal, xavier and orthogonal.') | ||
parser.add_argument('--use_early_stop', action='store_true', help='use early stop') | ||
parser.add_argument('--patience', type=int, default=5, | ||
help='How long to wait after last time validation loss improved.') | ||
parser.add_argument('--checkpoints_dir', type=str, default='/data/chwang/Log/ShareGAN', | ||
help='models are saved here') | ||
parser.add_argument('--name', type=str, default='DiagNet', help='saving name') | ||
parser.add_argument('--load_size', default=256, help='Size of the original image') | ||
parser.add_argument('--crop_size', default=128, help='Size of the patches extracted from the image') | ||
parser.add_argument('--save_freq', type=int, default=10, | ||
help='frequency of saving checkpoints at the end of epochs') | ||
args = parser.parse_args() | ||
|
||
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_ids | ||
|
||
model = networks3D.define_Cls(2, args.init_type, args.init_gain, args.gpu_ids) | ||
epochs = args.n_epochs | ||
optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(args.beta1, 0.999)) | ||
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.gamma) | ||
criterion = nn.CrossEntropyLoss() | ||
train_set = UnpairedDataset(data_list=['0', '1'], which_direction='AtoB', mode="train", load_size=args.load_size, | ||
crop_size=args.crop_size) | ||
valid_set = UnpairedDataset(data_list=['0', '1'], which_direction='AtoB', mode="valid", load_size=args.load_size, | ||
crop_size=args.crop_size) | ||
print('length train list:', len(train_set)) | ||
print('length valid list:', len(valid_set)) | ||
train_loader = DataLoader(train_set, | ||
batch_size=args.batch_size, | ||
num_workers=args.workers, | ||
shuffle=True) | ||
valid_loader = DataLoader(valid_set, | ||
batch_size=args.batch_size, | ||
num_workers=args.workers, | ||
shuffle=False) | ||
save_dir = args.checkpoints_dir + '/' + args.name + '/' | ||
mkdir(save_dir) | ||
train_data(model, train_loader, valid_loader, epochs, optimizer, scheduler, criterion, args.use_early_stop, | ||
args.patience, args.gpu_ids, save_dir, args.save_freq) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
import argparse | ||
import os | ||
|
||
import numpy as np | ||
import torch | ||
import torch.nn as nn | ||
from sklearn.metrics import f1_score, recall_score, roc_auc_score | ||
from torch.utils.data import DataLoader | ||
from tqdm import tqdm | ||
|
||
from components import networks3D | ||
from utils.UnpairedDataset import UnpairedDataset | ||
|
||
|
||
def evaluate_diagNetwork(model, valid_dataloaders): | ||
"""Evaluate a generator. | ||
Parameters: | ||
generator - - : (nn.Module) neural network generating PET images | ||
train_loader - - : (dataloader) the training loader | ||
test_loader - - : (dataloader) the testing loader | ||
Returns: | ||
df - - : (dataframe) the dataframe of the different Sets | ||
""" | ||
criterion = nn.CrossEntropyLoss() | ||
val_correct_sum = 0 | ||
val_simple_cnt = 0 | ||
val_loss = 0 | ||
y_val_true = [] | ||
y_val_pred = [] | ||
val_prob_all = [] | ||
val_label_all = [] | ||
with torch.no_grad(): | ||
model.eval() | ||
for ii, (images, _, labels) in enumerate(tqdm(valid_dataloaders)): | ||
images, labels = images.squeeze(1).cuda(), labels.cuda() | ||
outputs, _, _, _, _ = model(images) | ||
val_loss += criterion(outputs, labels).item() | ||
_, val_predicted = torch.max(outputs.data, 1) | ||
val_correct_sum += (labels.data == val_predicted).sum().item() | ||
val_simple_cnt += labels.size(0) | ||
y_val_true.extend(np.ravel(np.squeeze(labels.cpu().detach().numpy())).tolist()) | ||
y_val_pred.extend(np.ravel(np.squeeze(val_predicted.cpu().detach().numpy())).tolist()) | ||
val_prob_all.extend(outputs[:, | ||
1].cpu().detach().numpy()) | ||
val_label_all.extend(labels.cpu()) | ||
|
||
val_loss = val_loss / len(valid_dataloaders) | ||
val_acc = val_correct_sum / val_simple_cnt | ||
val_f1_score = f1_score(y_val_true, y_val_pred, average='weighted') | ||
val_recall = recall_score(y_val_true, y_val_pred, average='weighted') | ||
val_spe = recall_score(y_val_true, y_val_pred, pos_label=0, average='binary') | ||
val_auc = roc_auc_score(val_label_all, val_prob_all, average='weighted') | ||
|
||
print( | ||
'Val Loss:{:.3f}...'.format(val_loss), | ||
'Val Accuracy:{:.3f}...'.format(val_acc), | ||
'Val AUC:{:.3f}...'.format(val_auc), | ||
'Val F1 Score:{:.3f}'.format(val_f1_score), | ||
'val SPE:{:.3f}...'.format(val_spe), | ||
'Val SEN:{:.3f}...'.format(val_recall) | ||
) | ||
|
||
|
||
if __name__ == '__main__': | ||
# args definition | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument('--gpu_ids', default='7', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU') | ||
parser.add_argument('--workers', default=4, type=int, help='number of data loading workers') | ||
parser.add_argument('--init_type', type=str, default='normal', | ||
help='network initialization [normal|xavier|kaiming|orthogonal]') | ||
parser.add_argument('--init_gain', type=float, default=0.02, | ||
help='scaling factor for normal, xavier and orthogonal.') | ||
parser.add_argument('--load_path', type=str, default='/data/chwang/Log/ShareGAN/Cls.pth', | ||
help='models are saved here') | ||
parser.add_argument('--load_size', default=256, help='Size of the original image') | ||
parser.add_argument('--crop_size', default=128, help='Size of the patches extracted from the image') | ||
parser.add_argument('--dataset', default="adni", type=str, help='Types of dataset [adni|aibl|nacc]') | ||
args = parser.parse_args() | ||
|
||
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_ids | ||
# test set | ||
test_set = UnpairedDataset(data_list=['0', '1'], which_direction='AtoB', mode="test", load_size=args.load_size, | ||
crop_size=args.crop_size, dataset=args.dataset) | ||
test_loader = DataLoader(test_set, batch_size=1, shuffle=True, num_workers=args.workers, | ||
pin_memory=True) # Here are then fed to the network with a defined batch size | ||
print('length test list:', len(test_set)) | ||
# model definition | ||
print('initialize the model') | ||
model = networks3D.define_Cls(2, args.init_type, args.init_gain, args.gpu_ids) | ||
print('loading state dict from : {}'.format(args.load_path)) | ||
state_dict = torch.load(args.load_path, map_location='cuda') | ||
model.load_state_dict(state_dict) | ||
if len(args.gpu_ids) > 0: | ||
assert (torch.cuda.is_available()) | ||
if len(args.gpu_ids) > 1: | ||
model = torch.nn.DataParallel(model) | ||
model = model.cuda() | ||
|
||
# model evaluation | ||
evaluate_diagNetwork(model, test_loader) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
import pandas as pd | ||
import torch | ||
from torch.utils.data import DataLoader | ||
from tqdm import tqdm | ||
|
||
from components import networks3D | ||
from components.performance_metric import mean_absolute_error, peak_signal_to_noise_ratio, structural_similarity_index | ||
from options.test_options import TestOptions | ||
from utils.UnpairedDataset import UnpairedDataset | ||
|
||
|
||
def evaluate_generator(generator, test_loader, netG): | ||
"""Evaluate a generator. | ||
Parameters: | ||
generator - - : (nn.Module) neural network generating PET images | ||
train_loader - - : (dataloader) the training loader | ||
test_loader - - : (dataloader) the testing loader | ||
Returns: | ||
df - - : (dataframe) the dataframe of the different Sets | ||
""" | ||
res_test = [] | ||
|
||
cuda = True if torch.cuda.is_available() else False | ||
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor | ||
generator.eval() | ||
|
||
with torch.no_grad(): | ||
for i, batch in enumerate(tqdm(test_loader)): | ||
# Inputs MRI and PET | ||
real_mri = batch[0].type(Tensor) | ||
real_pet = batch[1].type(Tensor) | ||
if netG == 'ShareSynNet': | ||
fake_pet = generator(real_mri, alpha=0.) | ||
else: | ||
fake_pet = generator(real_mri) | ||
mae = mean_absolute_error(real_pet, fake_pet).item() | ||
psnr = peak_signal_to_noise_ratio(real_pet, fake_pet).item() | ||
ssim = structural_similarity_index(real_pet, fake_pet).item() | ||
res_test.append([mae, psnr, ssim]) | ||
|
||
df = pd.DataFrame([ | ||
pd.DataFrame(res_test, columns=['MAE', 'PSNR', 'SSIM']).mean().squeeze() | ||
], index=['Test set']).T | ||
return df | ||
|
||
|
||
if __name__ == '__main__': | ||
opt = TestOptions().parse() | ||
|
||
netG = networks3D.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm, | ||
not opt.no_dropout, opt.init_type, opt.init_gain, opt.gpu_ids) | ||
print('netG:\ttype:{}\tload_path:{}'.format(opt.netG, opt.load_path)) | ||
test_set = UnpairedDataset(data_list=['0', '1'], which_direction='AtoB', mode="test", load_size=opt.load_size, | ||
crop_size=opt.crop_size) | ||
test_loader = DataLoader(test_set, batch_size=1, shuffle=True, num_workers=opt.workers, | ||
pin_memory=True) # Here are then fed to the network with a defined batch siz | ||
|
||
print('lenght test list:', len(test_set)) | ||
|
||
state_dict = torch.load(opt.load_path) | ||
netG.load_state_dict(state_dict) | ||
test_df = evaluate_generator(netG, test_loader, opt.netG) | ||
print(test_df) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
import os | ||
import random | ||
import time | ||
|
||
import numpy as np | ||
import torch | ||
from torch.utils.data import DataLoader | ||
|
||
from models import create_model | ||
from options.train_options import TrainOptions | ||
from utils.UnpairedDataset import UnpairedDataset | ||
from utils.earlystop import EarlyStopping | ||
from utils.visualizer import Visualizer | ||
|
||
|
||
def seed_torch(seed=1029): | ||
random.seed(seed) | ||
os.environ['PYTHONHASHSEED'] = str(seed) | ||
np.random.seed(seed) | ||
torch.manual_seed(seed) | ||
torch.cuda.manual_seed(seed) | ||
torch.cuda.manual_seed_all(seed) # if you are using multi-GPU. | ||
torch.backends.cudnn.benchmark = False | ||
torch.backends.cudnn.deterministic = True | ||
|
||
|
||
if __name__ == '__main__': | ||
|
||
# ----- Loading the init options ----- | ||
opt = TrainOptions().parse() | ||
# [option] to seed the seed | ||
# seed_torch(opt.seed) | ||
|
||
# ----- Transformation and Augmentation process for the data ----- | ||
|
||
train_set = UnpairedDataset(data_list=['0', '1'], which_direction='AtoB', mode="train", load_size=opt.load_size, | ||
crop_size=opt.crop_size) | ||
valid_set = UnpairedDataset(data_list=['0', '1'], which_direction='AtoB', mode="valid", load_size=opt.load_size, | ||
crop_size=opt.crop_size) | ||
print('length train list:', len(train_set)) | ||
train_loader = DataLoader(train_set, batch_size=opt.batch_size, shuffle=True, num_workers=opt.workers, | ||
pin_memory=True) # Here are then fed to the network with a defined batch size | ||
valid_loader = DataLoader(valid_set, batch_size=opt.batch_size, shuffle=True, num_workers=opt.workers, | ||
pin_memory=False) # Here are then fed to the network with a defined batch size | ||
|
||
# initialize the early_stopping object | ||
if opt.use_earlystop: | ||
print('using early stop') | ||
early_stopping = EarlyStopping(patience=opt.patience, verbose=True) | ||
|
||
# ----------------------------------------------------- | ||
model = create_model(opt) # creation of the model | ||
model.setup(opt) | ||
visualizer = Visualizer(opt, train_loader) | ||
total_steps = 0 | ||
|
||
for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1): | ||
epoch_start_time = time.time() | ||
iter_data_time = time.time() | ||
epoch_iter = 0 | ||
|
||
for i, data in enumerate(train_loader): | ||
iter_start_time = time.time() | ||
if total_steps % opt.print_freq == 0: | ||
t_data = iter_start_time - iter_data_time | ||
total_steps += opt.batch_size | ||
epoch_iter += opt.batch_size | ||
model.set_input(data) | ||
if total_steps % opt.update_step == 0: | ||
model.optimize_parameters(opt.update_step, True) | ||
else: | ||
model.optimize_parameters(opt.update_step, False) | ||
if total_steps % opt.print_freq == 0: | ||
losses = model.get_current_losses() | ||
t = (time.time() - iter_start_time) / opt.batch_size | ||
visualizer.print_current_losses(epoch, epoch_iter, losses, t, t_data) | ||
|
||
if total_steps % opt.save_latest_freq == 0: | ||
print('saving the latest model (epoch %d, total_steps %d)' % | ||
(epoch, total_steps)) | ||
model.save_networks('latest') | ||
|
||
if total_steps % opt.eval_freq == 0: | ||
loss_G_list = [] | ||
with torch.no_grad(): | ||
model.eval() # prep model for evaluation | ||
for i, data in enumerate(valid_loader): | ||
# forward pass: compute predicted outputs by passing inputs to the model | ||
model.set_input(data) | ||
loss_G_list.append(model.get_current_losses()['G']) | ||
|
||
# early_stopping needs the validation loss to check if it has decresed, | ||
# and if it has, it will make a checkpoint of the current model | ||
if opt.use_earlystop: | ||
early_stopping(np.mean(loss_G_list), model, epoch) | ||
if early_stopping.early_stop: | ||
print("Early stopping from iteration") | ||
break | ||
if opt.use_earlystop: | ||
if early_stopping.early_stop: | ||
print("Early stopping from epoch") | ||
break | ||
|
||
iter_data_time = time.time() | ||
|
||
if epoch % opt.save_epoch_freq == 0: | ||
print('saving the model at the end of epoch %d, iters %d' % | ||
(epoch, total_steps)) | ||
model.save_networks('latest') | ||
model.save_networks(epoch) | ||
|
||
print('End of epoch %d / %d \t Time Taken: %d sec' % | ||
(epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time)) | ||
model.update_learning_rate() |
Oops, something went wrong.