Commit
initial commit
thibault-wch committed Aug 27, 2023
0 parents commit 444bb20
Showing 44 changed files with 2,938 additions and 0 deletions.
63 changes: 63 additions & 0 deletions Diag_pretrain.py
import argparse
import os

import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader

from components import networks3D
from utils.Diag_pretraining import train_data
from utils.UnpairedDataset import UnpairedDataset
from utils.utils import mkdir

if __name__ == '__main__':
# args definition
parser = argparse.ArgumentParser()
parser.add_argument('--gpu_ids', default='7', help='gpu ids: e.g. 0')
parser.add_argument('--workers', default=4, type=int, help='number of data loading workers')
parser.add_argument('--batch_size', type=int, default=4, help='input batch size')
parser.add_argument('--lr', type=float, default=0.01, help='initial learning rate for adam')
parser.add_argument('--n_epochs', type=int, default=100, help='number of epochs')
parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
parser.add_argument('--gamma', type=float, default=0.9, help='basic gamma value for exponentialLR')
parser.add_argument('--init_type', type=str, default='normal',
help='network initialization [normal|xavier|kaiming|orthogonal]')
parser.add_argument('--init_gain', type=float, default=0.02,
help='scaling factor for normal, xavier and orthogonal.')
parser.add_argument('--use_early_stop', action='store_true', help='use early stop')
parser.add_argument('--patience', type=int, default=5,
help='How long to wait after last time validation loss improved.')
parser.add_argument('--checkpoints_dir', type=str, default='/data/chwang/Log/ShareGAN',
help='models are saved here')
parser.add_argument('--name', type=str, default='DiagNet', help='saving name')
    parser.add_argument('--load_size', type=int, default=256, help='size of the original image')
    parser.add_argument('--crop_size', type=int, default=128, help='size of the patches extracted from the image')
parser.add_argument('--save_freq', type=int, default=10,
help='frequency of saving checkpoints at the end of epochs')
args = parser.parse_args()

os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_ids

model = networks3D.define_Cls(2, args.init_type, args.init_gain, args.gpu_ids)
epochs = args.n_epochs
optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(args.beta1, 0.999))
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.gamma)
criterion = nn.CrossEntropyLoss()
train_set = UnpairedDataset(data_list=['0', '1'], which_direction='AtoB', mode="train", load_size=args.load_size,
crop_size=args.crop_size)
valid_set = UnpairedDataset(data_list=['0', '1'], which_direction='AtoB', mode="valid", load_size=args.load_size,
crop_size=args.crop_size)
print('length train list:', len(train_set))
print('length valid list:', len(valid_set))
train_loader = DataLoader(train_set,
batch_size=args.batch_size,
num_workers=args.workers,
shuffle=True)
valid_loader = DataLoader(valid_set,
batch_size=args.batch_size,
num_workers=args.workers,
shuffle=False)
save_dir = args.checkpoints_dir + '/' + args.name + '/'
mkdir(save_dir)
train_data(model, train_loader, valid_loader, epochs, optimizer, scheduler, criterion, args.use_early_stop,
args.patience, args.gpu_ids, save_dir, args.save_freq)
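Note: utils.Diag_pretraining.train_data is not included in this commit. As a rough guide to what the call above expects, here is a minimal sketch of such a train/validate loop, assuming the dataset yields (image, _, label) triples as in Diag_test.py and that the classifier's first output is the logits; the bookkeeping and checkpoint naming are illustrative, not the repository's actual implementation.

# Hypothetical sketch of utils/Diag_pretraining.train_data -- not the repository's code.
import torch

def train_data(model, train_loader, valid_loader, epochs, optimizer,
               scheduler, criterion, use_early_stop, patience,
               gpu_ids, save_dir, save_freq):
    best_val_loss, stale_epochs = float('inf'), 0
    for epoch in range(epochs):
        model.train()
        for images, _, labels in train_loader:
            images, labels = images.squeeze(1).cuda(), labels.cuda()
            optimizer.zero_grad()
            logits = model(images)[0]  # assumed: first element of the output tuple is the logits
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()
        scheduler.step()

        # validation pass
        model.eval()
        val_loss, correct, total = 0.0, 0, 0
        with torch.no_grad():
            for images, _, labels in valid_loader:
                images, labels = images.squeeze(1).cuda(), labels.cuda()
                logits = model(images)[0]
                val_loss += criterion(logits, labels).item()
                correct += (logits.argmax(dim=1) == labels).sum().item()
                total += labels.size(0)
        val_loss /= len(valid_loader)
        print(f'epoch {epoch + 1}: val loss {val_loss:.3f}, val acc {correct / total:.3f}')

        if (epoch + 1) % save_freq == 0:
            torch.save(model.state_dict(), save_dir + f'epoch_{epoch + 1}.pth')
        if use_early_stop:  # naive patience-based early stopping
            if val_loss < best_val_loss:
                best_val_loss, stale_epochs = val_loss, 0
            else:
                stale_epochs += 1
                if stale_epochs >= patience:
                    print('early stopping')
                    break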
102 changes: 102 additions & 0 deletions Diag_test.py
import argparse
import os

import numpy as np
import torch
import torch.nn as nn
from sklearn.metrics import f1_score, recall_score, roc_auc_score
from torch.utils.data import DataLoader
from tqdm import tqdm

from components import networks3D
from utils.UnpairedDataset import UnpairedDataset


def evaluate_diagNetwork(model, valid_dataloaders):
    """Evaluate the diagnosis (classification) network.
    Parameters:
        model - - : (nn.Module) the classification network to evaluate
        valid_dataloaders - - : (dataloader) the evaluation loader
    Prints:
        the loss, accuracy, AUC, F1 score, specificity (SPE) and sensitivity (SEN)
    """
criterion = nn.CrossEntropyLoss()
val_correct_sum = 0
    val_sample_cnt = 0
val_loss = 0
y_val_true = []
y_val_pred = []
val_prob_all = []
val_label_all = []
with torch.no_grad():
model.eval()
for ii, (images, _, labels) in enumerate(tqdm(valid_dataloaders)):
images, labels = images.squeeze(1).cuda(), labels.cuda()
outputs, _, _, _, _ = model(images)
val_loss += criterion(outputs, labels).item()
_, val_predicted = torch.max(outputs.data, 1)
val_correct_sum += (labels.data == val_predicted).sum().item()
            val_sample_cnt += labels.size(0)
y_val_true.extend(np.ravel(np.squeeze(labels.cpu().detach().numpy())).tolist())
y_val_pred.extend(np.ravel(np.squeeze(val_predicted.cpu().detach().numpy())).tolist())
            val_prob_all.extend(outputs[:, 1].cpu().detach().numpy())
            val_label_all.extend(labels.cpu().numpy())

val_loss = val_loss / len(valid_dataloaders)
    val_acc = val_correct_sum / val_sample_cnt
val_f1_score = f1_score(y_val_true, y_val_pred, average='weighted')
val_recall = recall_score(y_val_true, y_val_pred, average='weighted')
val_spe = recall_score(y_val_true, y_val_pred, pos_label=0, average='binary')
val_auc = roc_auc_score(val_label_all, val_prob_all, average='weighted')

    print(
        'Val Loss:{:.3f}...'.format(val_loss),
        'Val Accuracy:{:.3f}...'.format(val_acc),
        'Val AUC:{:.3f}...'.format(val_auc),
        'Val F1 Score:{:.3f}...'.format(val_f1_score),
        'Val SPE:{:.3f}...'.format(val_spe),
        'Val SEN:{:.3f}'.format(val_recall)
    )


if __name__ == '__main__':
# args definition
parser = argparse.ArgumentParser()
    parser.add_argument('--gpu_ids', default='7', help='gpu ids, e.g. 0 or 0,1,2; use -1 for CPU')
parser.add_argument('--workers', default=4, type=int, help='number of data loading workers')
parser.add_argument('--init_type', type=str, default='normal',
help='network initialization [normal|xavier|kaiming|orthogonal]')
parser.add_argument('--init_gain', type=float, default=0.02,
help='scaling factor for normal, xavier and orthogonal.')
    parser.add_argument('--load_path', type=str, default='/data/chwang/Log/ShareGAN/Cls.pth',
                        help='path of the model checkpoint to load')
    parser.add_argument('--load_size', type=int, default=256, help='Size of the original image')
    parser.add_argument('--crop_size', type=int, default=128, help='Size of the patches extracted from the image')
parser.add_argument('--dataset', default="adni", type=str, help='Types of dataset [adni|aibl|nacc]')
args = parser.parse_args()

os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_ids
# test set
test_set = UnpairedDataset(data_list=['0', '1'], which_direction='AtoB', mode="test", load_size=args.load_size,
crop_size=args.crop_size, dataset=args.dataset)
    test_loader = DataLoader(test_set, batch_size=1, shuffle=True, num_workers=args.workers,
                             pin_memory=True)  # batches are then fed to the network with the defined batch size
print('length test list:', len(test_set))
# model definition
print('initialize the model')
model = networks3D.define_Cls(2, args.init_type, args.init_gain, args.gpu_ids)
print('loading state dict from : {}'.format(args.load_path))
state_dict = torch.load(args.load_path, map_location='cuda')
model.load_state_dict(state_dict)
if len(args.gpu_ids) > 0:
assert (torch.cuda.is_available())
if len(args.gpu_ids) > 1:
model = torch.nn.DataParallel(model)
model = model.cuda()

# model evaluation
evaluate_diagNetwork(model, test_loader)
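One detail worth noting in evaluate_diagNetwork above: specificity is obtained from sklearn's recall_score with pos_label=0, because recall of the negative class, TN / (TN + FP), is exactly specificity. A tiny self-contained check:

from sklearn.metrics import recall_score

y_true = [0, 0, 0, 1, 1]
y_pred = [0, 1, 0, 1, 0]
print(recall_score(y_true, y_pred))               # sensitivity: TP/(TP+FN) = 1/2 = 0.5
print(recall_score(y_true, y_pred, pos_label=0))  # specificity: TN/(TN+FP) = 2/3 ~ 0.667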
65 changes: 65 additions & 0 deletions Frame_test.py
import pandas as pd
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from components import networks3D
from components.performance_metric import mean_absolute_error, peak_signal_to_noise_ratio, structural_similarity_index
from options.test_options import TestOptions
from utils.UnpairedDataset import UnpairedDataset


def evaluate_generator(generator, test_loader, netG):
    """Evaluate a generator.
    Parameters:
        generator - - : (nn.Module) neural network generating PET images
        test_loader - - : (dataloader) the testing loader
        netG - - : (str) name of the generator architecture
    Returns:
        df - - : (dataframe) the mean MAE / PSNR / SSIM over the test set
    """
res_test = []

cuda = True if torch.cuda.is_available() else False
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
generator.eval()

with torch.no_grad():
for i, batch in enumerate(tqdm(test_loader)):
# Inputs MRI and PET
real_mri = batch[0].type(Tensor)
real_pet = batch[1].type(Tensor)
if netG == 'ShareSynNet':
fake_pet = generator(real_mri, alpha=0.)
else:
fake_pet = generator(real_mri)
mae = mean_absolute_error(real_pet, fake_pet).item()
psnr = peak_signal_to_noise_ratio(real_pet, fake_pet).item()
ssim = structural_similarity_index(real_pet, fake_pet).item()
res_test.append([mae, psnr, ssim])

df = pd.DataFrame([
pd.DataFrame(res_test, columns=['MAE', 'PSNR', 'SSIM']).mean().squeeze()
], index=['Test set']).T
return df


if __name__ == '__main__':
opt = TestOptions().parse()

netG = networks3D.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm,
not opt.no_dropout, opt.init_type, opt.init_gain, opt.gpu_ids)
print('netG:\ttype:{}\tload_path:{}'.format(opt.netG, opt.load_path))
test_set = UnpairedDataset(data_list=['0', '1'], which_direction='AtoB', mode="test", load_size=opt.load_size,
crop_size=opt.crop_size)
    test_loader = DataLoader(test_set, batch_size=1, shuffle=True, num_workers=opt.workers,
                             pin_memory=True)  # batches are then fed to the network with the defined batch size

    print('length test list:', len(test_set))

state_dict = torch.load(opt.load_path)
netG.load_state_dict(state_dict)
test_df = evaluate_generator(netG, test_loader, opt.netG)
print(test_df)
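components.performance_metric is not shown in this diff. Under the standard definitions, the first two metrics could be sketched as below; max_val=1.0 assumes intensities normalized to [0, 1], and SSIM is omitted since it requires a windowed implementation.

# Hypothetical sketches of the metrics -- components/performance_metric.py is not in this commit.
import torch

def mean_absolute_error(real, fake):
    # mean absolute voxel-wise difference
    return torch.mean(torch.abs(real - fake))

def peak_signal_to_noise_ratio(real, fake, max_val=1.0):
    # PSNR = 10 * log10(MAX^2 / MSE); max_val=1.0 assumes [0, 1]-normalized intensities
    mse = torch.mean((real - fake) ** 2)
    return 10.0 * torch.log10(max_val ** 2 / mse)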
114 changes: 114 additions & 0 deletions Frame_train.py
import os
import random
import time

import numpy as np
import torch
from torch.utils.data import DataLoader

from models import create_model
from options.train_options import TrainOptions
from utils.UnpairedDataset import UnpairedDataset
from utils.earlystop import EarlyStopping
from utils.visualizer import Visualizer


def seed_torch(seed=1029):
random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True


if __name__ == '__main__':

# ----- Loading the init options -----
opt = TrainOptions().parse()
    # [optional] seed the run for reproducibility
    # seed_torch(opt.seed)

# ----- Transformation and Augmentation process for the data -----

train_set = UnpairedDataset(data_list=['0', '1'], which_direction='AtoB', mode="train", load_size=opt.load_size,
crop_size=opt.crop_size)
valid_set = UnpairedDataset(data_list=['0', '1'], which_direction='AtoB', mode="valid", load_size=opt.load_size,
crop_size=opt.crop_size)
print('length train list:', len(train_set))
    train_loader = DataLoader(train_set, batch_size=opt.batch_size, shuffle=True, num_workers=opt.workers,
                              pin_memory=True)  # batches are then fed to the network with the defined batch size
    valid_loader = DataLoader(valid_set, batch_size=opt.batch_size, shuffle=True, num_workers=opt.workers,
                              pin_memory=False)

# initialize the early_stopping object
if opt.use_earlystop:
print('using early stop')
early_stopping = EarlyStopping(patience=opt.patience, verbose=True)

# -----------------------------------------------------
model = create_model(opt) # creation of the model
model.setup(opt)
visualizer = Visualizer(opt, train_loader)
total_steps = 0

for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
epoch_start_time = time.time()
iter_data_time = time.time()
epoch_iter = 0

for i, data in enumerate(train_loader):
iter_start_time = time.time()
if total_steps % opt.print_freq == 0:
t_data = iter_start_time - iter_data_time
total_steps += opt.batch_size
epoch_iter += opt.batch_size
model.set_input(data)
if total_steps % opt.update_step == 0:
model.optimize_parameters(opt.update_step, True)
else:
model.optimize_parameters(opt.update_step, False)
if total_steps % opt.print_freq == 0:
losses = model.get_current_losses()
t = (time.time() - iter_start_time) / opt.batch_size
visualizer.print_current_losses(epoch, epoch_iter, losses, t, t_data)

if total_steps % opt.save_latest_freq == 0:
print('saving the latest model (epoch %d, total_steps %d)' %
(epoch, total_steps))
model.save_networks('latest')

if total_steps % opt.eval_freq == 0:
loss_G_list = []
with torch.no_grad():
model.eval() # prep model for evaluation
for i, data in enumerate(valid_loader):
# forward pass: compute predicted outputs by passing inputs to the model
model.set_input(data)
loss_G_list.append(model.get_current_losses()['G'])

                # early_stopping needs the validation loss to check if it has decreased,
                # and if it has, it will make a checkpoint of the current model
if opt.use_earlystop:
early_stopping(np.mean(loss_G_list), model, epoch)
if early_stopping.early_stop:
print("Early stopping from iteration")
break
if opt.use_earlystop:
if early_stopping.early_stop:
print("Early stopping from epoch")
break

iter_data_time = time.time()

if epoch % opt.save_epoch_freq == 0:
print('saving the model at the end of epoch %d, iters %d' %
(epoch, total_steps))
model.save_networks('latest')
model.save_networks(epoch)

print('End of epoch %d / %d \t Time Taken: %d sec' %
(epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))
model.update_learning_rate()
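utils.earlystop.EarlyStopping is also not part of this diff. A minimal version compatible with the calls above (constructed with patience and verbose, invoked as early_stopping(val_loss, model, epoch), exposing an early_stop flag) might look like the following; saving via model.save_networks mirrors the usage in this file, but the 'best' tag is an assumption.

# Hypothetical sketch of utils/earlystop.EarlyStopping -- not the repository's code.
import numpy as np

class EarlyStopping:
    """Stops training when the validation loss has not improved for `patience` evaluations."""

    def __init__(self, patience=5, verbose=False):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_loss = np.inf
        self.early_stop = False

    def __call__(self, val_loss, model, epoch):
        if val_loss < self.best_loss:
            if self.verbose:
                print(f'epoch {epoch}: val loss improved ({self.best_loss:.4f} -> {val_loss:.4f})')
            self.best_loss = val_loss
            self.counter = 0
            model.save_networks('best')  # the model wrapper's saver (see above); tag name is an assumption
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True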