train-boostingNet.py
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import torch
import tqdm
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
import wandb

import boosting_data
import config
import data
from nets import ResNext

plt.style.use('ggplot')
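
# Sequential boosting of ResNext classifiers on the RFMiD dataset: each round
# reads the per-class weights written by the previous round, trains a new model,
# and writes updated weights derived from its validation errors.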
def main():
    boosting_number = 5
    for n in range(boosting_number):
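        # Each round trains a fresh ResNext; the per-class weights produced by the
        # previous round are read from boosting/weight_b{n}.csv.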
        model = ResNext()

        # data_df = pd.read_csv(config.CSV_PATH)
        # train_df, val_df = train_test_split(data_df, test_size=0.1)
        training_img_path = '../Training_Set/Training/'
        evaluation_img_path = '../Evaluation_Set/Validation'
        train_df = '../Training_Set/RFMiD_Training_Labels.csv'
        val_df = '../Evaluation_Set/RFMiD_Validation_Labels.csv'
        weight_path = f'boosting/weight_b{n}.csv'

        trainset = boosting_data.ISBIDataset(train_df, training_img_path, weight_path, testing=False, input_size=732)
        valset = boosting_data.ISBIDataset(val_df, evaluation_img_path, weight_csv=None, testing=True, input_size=732)
        trainloader = DataLoader(trainset, batch_size=16, shuffle=True, num_workers=20)
        valloader = DataLoader(valset, batch_size=16, shuffle=False, num_workers=20)
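
        # Log to Weights & Biases and keep the three checkpoints with the lowest
        # validation loss.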
        wandb.init(project='boosting-ResNext')
        wandb_logger = WandbLogger(project='boosting-ResNext')
        checkpoint_callback = ModelCheckpoint(monitor='val_loss',
                                              # save_last=True,
                                              dirpath='boosting/checkpoints_ResNext101_732/checkpoints',
                                              # every_n_val_epochs=10,
                                              filename=f'ISBI-WeightedBCE-boosting-ResNext101_732x732-b{n}' + '-{epoch:03d}-{val_loss:.4f}',
                                              save_top_k=3,
                                              mode='min')
        trainer = pl.Trainer(gpus=config.DEVICES,
                             # num_nodes=2,
                             logger=wandb_logger,
                             log_every_n_steps=config.LOG_STEP,
                             callbacks=[checkpoint_callback],
                             max_epochs=15)
        trainer.fit(model, trainloader, val_dataloaders=valloader)
        print("Finished Training")
        N = len(valset)
        batch_size = 16
        outs_valid = np.zeros((N, 29))
        labels_valid = np.zeros((N, 29))
        model.eval()  # disable dropout / batch-norm updates for inference
        with torch.no_grad():
            for i, (imgs, label, w) in enumerate(tqdm.tqdm(valloader)):
                idx = i * batch_size
                imgs = imgs.type(torch.FloatTensor)
                out = model(imgs).cpu().numpy()
                outs_valid[idx:idx + len(out), :] = out
                labels_valid[idx:idx + len(label), :] = label.cpu().numpy()
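
        # Boosting-style reweighting: threshold the sigmoid outputs at 0.5 and count,
        # per class, how often the prediction disagrees with the ground truth.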
        sig = torch.nn.Sigmoid()
        weight = np.zeros((29,))
        count = np.zeros((29,))
        rounded_valid_pred = np.round(sig(torch.tensor(outs_valid)).numpy()).astype('int')
        for i in range(labels_valid.shape[0]):
            for j in range(labels_valid.shape[1]):
                if labels_valid[i][j] != rounded_valid_pred[i][j]:
                    weight[j] += 1
                if labels_valid[i][j] == 1:
                    count[j] += 1
        # print(weight)
        weight[weight == 0] = 1  # classes with no errors still get a minimal weight
        weight = weight / count  # misclassifications per positive validation sample for each class
        # weight = weight / min(weight)

        # Per-class weights for round n+1.
        weight_df = pd.DataFrame(weight)
        next_id = n + 1
        weight_df.to_csv(f'boosting/checkpoints_ResNext101_732/weight_b{next_id}.csv', index=False)
        # print(weight)
        # print(count)


if __name__ == "__main__":
    main()