Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make initial training pipeline #7

Closed
wants to merge 15 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@
/synthlung.egg*/
/build/*
*__pycache__*
.coverage
.coverage
*.pth
67 changes: 55 additions & 12 deletions synthlung/__main__.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,28 @@
import argparse
import json
import tqdm

from synthlung.utils.tumor_isolation_pipeline import TumorCropPipeline
from synthlung.utils.dataset_formatter import MSDImageSourceFormatter, MSDGenerateJSONFormatter
from synthlung.utils.dataset_formatter import MSDImageSourceFormatter, JsonSeedGenerator, JsonTrainingGenerator
from synthlung.utils.tumor_insertion_pipeline import InsertTumorPipeline
from synthlung.utils.lung_segmentation_pipeline import LungMaskPipeline, HostJsonGenerator

from synthlung.train_pipeline.train import TrainPipeline

from synthlung.action_provider.log_remote_provider import LogRemoteProvider
from synthlung.action_provider.log_local_provider import LogLocalProvider

from lungmask import LMInferer
import json

def seed():
    # Crop tumors out of the formatted source images and emit seed JSON.
    # NOTE(review): this span is a rendered diff — the paired assignments and
    # the two formatter blocks below look like old/new diff residue (asset
    # layout moved from ./assets/... to ./assets/images/...); confirm against
    # the merged file. The later assignment wins at runtime.
    json_file_path = "./assets/source/dataset.json"
    json_file_path = "./assets/images/source/dataset.json"

    with open(json_file_path, 'r') as json_file:
        image_dict = json.load(json_file)
    # Run the cropping pipeline over every image entry in the dataset JSON.
    crop_pipeline = TumorCropPipeline()
    crop_pipeline(image_dict)
    # NOTE(review): old formatter (MSDGenerateJSONFormatter) followed by its
    # replacement (JsonSeedGenerator) — likely only the second pair belongs.
    formatter = MSDGenerateJSONFormatter("./assets/seeds/")
    formatter.generate_json()
    formatter = JsonSeedGenerator("./assets/images/seeds/")
    formatter.generate_json_seeds()

def format_msd():
formatter = MSDImageSourceFormatter()
Expand All @@ -23,32 +31,62 @@ def format_msd():

def generate_randomized_tumors():
    # Insert isolated tumor seeds into host images at randomized locations,
    # then generate a training JSON next to the produced images.
    tumor_inserter = InsertTumorPipeline()
    # NOTE(review): duplicate assignments below look like old/new diff residue
    # (asset layout move); the later assignment wins at runtime.
    json_file_path = "./assets/source/dataset.json"
    json_file_path = "./assets/images/source/dataset.json"
    with open(json_file_path, 'r') as json_file:
        image_dict = json.load(json_file)

    json_seed_path = "./assets/seeds/dataset.json"
    json_seed_path = "./assets/images/seeds/dataset.json"
    with open(json_seed_path, 'r') as json_file:
        seeds_dict = json.load(json_file)

    tumor_inserter(image_dict, seeds_dict)
    round_dict = tumor_inserter.getDict()

    # Derive the output directory from the first generated image's path
    # (everything before the "0_image" filename fragment).
    path = round_dict[0]["randomized_image"].split("0_image")[0]
    formatter = JsonTrainingGenerator(path)
    formatter.generate_json()
    return path

def mask_hosts():
    # Produce lung masks for every host image using the lungmask inferer,
    # then write a hosts dataset JSON describing the results.
    lung_masker = LMInferer()
    host_masker = LungMaskPipeline(lung_masker)
    # NOTE(review): duplicate assignments below look like old/new diff residue
    # (asset layout move); the later assignment wins at runtime.
    json_file_path = "./assets/source/dataset.json"
    json_file_path = "./assets/images/source/dataset.json"
    with open(json_file_path, 'r') as json_file:
        image_dict = json.load(json_file)

    host_masker(image_dict)
    json_generator = HostJsonGenerator('./assets/hosts/')
    json_generator = HostJsonGenerator('./assets/images/hosts/')
    json_generator.generate_json()

def train(config_path):
    """Run the training pipeline with local (no-op) logging.

    config_path: path to a training config JSON. NOTE(review): currently
    unused downstream — LogLocalProvider reads its own hard-coded config
    file; the parameter is kept for interface stability.

    Exits the process with status 0 on completion.
    """
    # The original read config_path into an unused local (`data`) and also
    # bound an unused hard-coded `path`; that dead code is removed here.
    logLocal = LogLocalProvider()

    trainPipeline = TrainPipeline(logLocal)
    trainPipeline.verify_config()
    trainPipeline()

    exit(0)

def train_remote():
    """Run the training pipeline while logging to the remote service."""
    print("Now starting remote session")

    provider = LogRemoteProvider()
    pipeline = TrainPipeline(provider)
    pipeline.verify_config()
    pipeline()


def main():
parser = argparse.ArgumentParser(description="Create your synthetic lung tumors!")

parser.add_argument("action", choices=["format", "seed", "host", "generate"], help="Action to perform")
parser.add_argument("action", choices=["format", "seed", "host", "generate", "train", "train_remote"], help="Action to perform")
parser.add_argument("--dataset", help="Dataset to format", choices=["msd"])
parser.add_argument("--config", help="Path to config to configure training")
args = parser.parse_args()

if args.action == "format":
Expand All @@ -57,10 +95,15 @@ def main():
elif args.action == "seed":
seed()
elif args.action == "generate":
if(args.dataset == "msd"):
generate_randomized_tumors()
generate_randomized_tumors()
elif args.action == "host":
if(args.dataset == "msd"):
mask_hosts()
elif args.action == "train":
config_path = "./synthlung/config.json"
#config_path = args.config # hardcoded for easy debugging
train(config_path)
elif args.action == "train_remote":
train_remote()
else:
print("Action not recognized")
Empty file.
19 changes: 19 additions & 0 deletions synthlung/action_provider/action_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from abc import abstractmethod, ABC

class ActionProvider(ABC):
    """Abstract backend for training-session bookkeeping.

    Implementations supply a session lifecycle, a training configuration,
    and a sink for per-step log records (see LogLocalProvider /
    LogRemoteProvider).
    """

    @abstractmethod
    def start_new_session(self) -> None:
        """Begin a new logging session."""
        pass

    @abstractmethod
    def get_session_id(self):
        """Return the identifier of the current session.

        Annotation fix: the original declared ``-> None`` although
        implementations return an id.
        """
        pass

    @abstractmethod
    def get_config(self) -> dict:
        """Return the training configuration as a dict.

        Annotation fix: the original declared ``-> None`` although
        implementations return the parsed config.
        """
        pass

    @abstractmethod
    def save_log(self, log_dict: dict) -> None:
        """Persist one log record (a dict of metric name -> value)."""
        pass
22 changes: 22 additions & 0 deletions synthlung/action_provider/log_local_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import json

from synthlung.action_provider.action_provider import ActionProvider

class LogLocalProvider(ActionProvider):
    """No-op logging backend: config comes from disk, logs are discarded."""

    def __init__(self):
        # No state is needed for local runs.
        pass

    def start_new_session(self):
        # Local sessions require no setup.
        pass

    def get_session_id(self):
        # Local runs always report the fixed session id 0.
        return 0

    def get_config(self):
        """Load and return the training config from the repo-local JSON file."""
        with open("./synthlung/config.json", "r") as config_file:
            return json.load(config_file)

    def save_log(self, log_dict):
        # Log records are intentionally dropped in local mode.
        pass
51 changes: 51 additions & 0 deletions synthlung/action_provider/log_remote_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import requests
import json

from synthlung.action_provider.action_provider import ActionProvider

class LogRemoteProvider(ActionProvider):
    """Logging backend backed by a remote REST service on localhost:5167.

    Sessions are created via GET /api/Session/newSession, the training
    config is fetched from /api/Config/<id>, and log records are POSTed to
    /api/Log with the session id attached.
    """

    def __init__(self):
        self.start_session_url = "http://localhost:5167/api/Session/newSession"
        self.new_log_url = "http://localhost:5167/api/Log"
        self.get_config_url = "http://localhost:5167/api/Config/000000000000000000000000"
        self.header = {"Content-Type" : "application/json"}
        self.started = False
        # Certificate verification is disabled for the local dev server.
        self.verify = False
        # Bug fix: requests without a timeout can hang forever if the
        # server is unreachable; fail fast instead.
        self.timeout = 10

    def start_new_session(self):
        """Ask the server for a new session; record its id on success.

        NOTE(review): non-200 responses are silently ignored here — the
        failure only surfaces later via get_session_id().
        """
        response = requests.get(self.start_session_url, headers=self.header,
                                verify=self.verify, timeout=self.timeout)

        if (response.status_code == 200):
            response_data = response.json()
            self.started = True
            self.session_id = response_data["id"]

    def get_session_id(self):
        """Return the remote session id; exit(2) if no session was started."""
        if (self.started):
            return self.session_id

        print("Session has not started yet")
        exit(2)

    def get_config(self):
        """Fetch and return the remote training config; exit(3) on failure."""
        response = requests.get(self.get_config_url, headers=self.header,
                                verify=self.verify, timeout=self.timeout)

        if (response.status_code == 200):
            return response.json()

        print("Could not fetch remote config")
        exit(3)

    def save_log(self, log_dict):
        """POST one log record, tagged with the session id.

        Silently does nothing if no session has been started (best-effort).
        """
        if (not self.started):
            return

        session_id_dict = {"Session_Id": self.session_id}
        log_dict = {**log_dict, **session_id_dict}
        requests.post(self.new_log_url, data=json.dumps(log_dict),
                      headers=self.header, verify=self.verify,
                      timeout=self.timeout)


# Manual smoke test against a locally running log server.
if (__name__ == "__main__"):
    r = LogRemoteProvider()
    r.start_new_session()
    print(r.get_session_id())
8 changes: 8 additions & 0 deletions synthlung/config.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{
"learningRate": 0.001,
"loss": "CE",
"optimizer": "Adam",
"network": "BasicUnet",
"trainImagesDatasetPath": "./assets/images/source/dataset.json",
"modelSavePath": "./model.pth"
}
Empty file added synthlung/dataset/__init__.py
Empty file.
22 changes: 22 additions & 0 deletions synthlung/dataset/synthlung_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from torch.utils.data import Dataset
from monai.transforms import (Compose, LoadImaged, ToTensord, DivisiblePadd, Resized, AddChanneld)
from synthlung.utils.send_to_cudad import SendToCudad

class SynthlungDataset(Dataset):
    """Torch dataset yielding (image, label) tensor pairs via MONAI transforms.

    Each element of *data* is expected to be a dict whose 'image' and
    'label' values are file paths loadable by LoadImaged — TODO confirm
    against the training JSON producer.
    """

    def __init__(self, data: list):
        # Annotation fix: the original hint `[dict]` was a list literal,
        # not a valid type annotation.
        self.data = data
        self.compose_load = Compose([
            LoadImaged(keys=['image', 'label']),
            # Pad spatial dims to multiples of 16 so U-Net down/upsampling aligns.
            DivisiblePadd(keys=['image', 'label'], k=16),
            AddChanneld(keys=['image', 'label']),
            # Resize in-plane to 64x64; -1 leaves the last spatial dim unchanged.
            Resized(keys=['image', 'label'], spatial_size=(64, 64, -1,)),
            ToTensord(keys=['image', 'label']),
            SendToCudad(keys=['image', 'label'])
        ])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        # Transforms run on every access; nothing is cached.
        loaded_data = self.compose_load(self.data[index])
        return loaded_data["image"], loaded_data["label"]
Empty file added synthlung/networks/__init__.py
Empty file.
12 changes: 12 additions & 0 deletions synthlung/networks/simple_network.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import torch.nn as nn

class SimpleNN(nn.Module):
    """Minimal 3D conv + ReLU network used as a sanity-check model.

    Single-channel in, single-channel out; kernel 3 with padding 1 keeps
    the spatial dimensions unchanged.
    """

    def __init__(self):
        super(SimpleNN, self).__init__()
        self.conv3d = nn.Conv3d(in_channels=1, out_channels=1,
                                kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply the convolution followed by ReLU; output shape equals input shape."""
        return self.relu(self.conv3d(x))
Empty file added synthlung/providers/__init__.py
Empty file.
9 changes: 9 additions & 0 deletions synthlung/providers/loss_function_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import torch.nn as nn

class LossFunctionProvider():
    """Maps the config's "loss" key to a torch loss-function instance."""

    def __init__(self, config):
        if (config["loss"] == "CE"):
            self.criterion = nn.CrossEntropyLoss()
        else:
            # Bug fix: an unknown loss name previously left self.criterion
            # unset, surfacing as an AttributeError only at call time.
            raise ValueError(f"Unsupported loss function: {config['loss']}")

    def __call__(self):
        """Return the configured loss criterion."""
        return self.criterion
12 changes: 12 additions & 0 deletions synthlung/providers/network_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from synthlung.networks.simple_network import SimpleNN
from monai.networks.nets.basic_unet import BasicUNet

class NetworkProvider():
    """Maps the config's "network" key to a model instance.

    Supported names: "SimpleNN" (local toy network) and "BasicUnet"
    (MONAI BasicUNet, 3D, 1-in/1-out channels).
    """

    def __init__(self, config) -> None:
        if (config["network"] == "SimpleNN"):
            self.network = SimpleNN()
        elif(config['network'] == "BasicUnet"):
            self.network = BasicUNet(spatial_dims=3, in_channels=1, out_channels=1)
        else:
            # Bug fix: an unknown network name previously left self.network
            # unset, surfacing as an AttributeError only at call time.
            raise ValueError(f"Unsupported network: {config['network']}")

    def __call__(self):
        """Return the constructed network (same instance on every call)."""
        return self.network
10 changes: 10 additions & 0 deletions synthlung/providers/optimizer_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import torch.optim as optim

class OptimizerProvider():
    """Builds a torch optimizer for a model from the training config."""

    def __init__(self, config):
        # The learning rate applies to whichever optimizer is selected.
        self.lr = config["learningRate"]
        if (config["optimizer"] == "Adam"):
            self.optimizer = optim.Adam
        else:
            # Bug fix: an unknown optimizer name previously left
            # self.optimizer unset, surfacing as an AttributeError at call time.
            raise ValueError(f"Unsupported optimizer: {config['optimizer']}")

    def __call__(self, model):
        """Return an optimizer instance bound to *model*'s parameters."""
        return self.optimizer(model.parameters(), lr=self.lr)
Empty file.
38 changes: 38 additions & 0 deletions synthlung/train_pipeline/train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import monai.config
import json

from synthlung.dataset.synthlung_dataset import SynthlungDataset
from synthlung.providers.loss_function_provider import LossFunctionProvider
from synthlung.providers.network_provider import NetworkProvider
from synthlung.providers.optimizer_provider import OptimizerProvider
from synthlung.train_pipeline.trainer import Trainer
from torch.utils.data import DataLoader

class TrainPipeline():
    """Wires config, model, optimizer, loss and data into a Trainer run."""

    # Force MONAI's image-loading backend (executes once, at class creation).
    monai.config.BACKEND = "Nibabel"

    def __init__(self, provider) -> None:
        """provider: an ActionProvider supplying the training config
        (local file or remote service)."""
        self.provider = provider
        self.config = self.provider.get_config()
        self.loss_function_provider = LossFunctionProvider(self.config)
        self.model_provider = NetworkProvider(self.config)
        self.optimizer_provider = OptimizerProvider(self.config)

        model = self.model_provider()
        model.to(device='cuda')

        # Fix: build the optimizer over the SAME model object being trained.
        # The original called self.model_provider() a second time here —
        # equivalent today only because NetworkProvider caches one instance.
        self.trainer = Trainer(model, self.optimizer_provider(model),
                               self.loss_function_provider(),
                               self.config["modelSavePath"])

        with open(self.config["trainImagesDatasetPath"], "r") as f:
            self.data = json.load(f)
        self.dataset = SynthlungDataset(self.data)
        self.dataloader = DataLoader(self.dataset, batch_size=1, shuffle=True)

    def __call__(self) -> None:
        """Run a single training epoch over the dataloader."""
        self.trainer(self.dataloader, 1)

    def verify_config(self) -> None:
        """Print the GPU in use, or exit(1) when no CUDA device is found."""
        if (monai.config.deviceconfig.get_gpu_info()["Has CUDA"]):
            print(f"Running on: {monai.config.deviceconfig.get_gpu_info()['GPU 0 Name']}")
            return
        print("No CUDA found. Check config!")
        exit(1)
49 changes: 49 additions & 0 deletions synthlung/train_pipeline/trainer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import numpy as np
import torch as T

class Trainer():
    """Runs the optimization loop and saves the model weights afterwards."""

    def __init__(self, model, optimizer, loss_function, save_weight_path, logger = None) -> None:
        self.M = model              # network being trained
        self.O = optimizer          # optimizer bound to model's parameters
        self.L = loss_function      # criterion: L(model(x), y)
        self.save_path = save_weight_path

        self.train_losses = []
        self.current_epoch = 0
        self.logger = logger        # optional; currently unused (see __call__)

    def __call__(self, dataloader, n_epochs = 20):
        """Train for n_epochs and return (epochs, losses) as numpy arrays.

        *epochs* holds a fractional epoch value per batch so the two arrays
        line up for plotting. Weights are saved to self.save_path at the end.
        """
        self.current_epoch = 0
        losses = []
        epochs = []
        for epoch in range(n_epochs):
            print(f"Epoch {epoch}/{n_epochs}")
            batches_per_epoch = len(dataloader)
            epochs += [epoch + i / batches_per_epoch for i in range(batches_per_epoch)]

            losses += self._train_one_epoch(dataloader, self.M)

            # Fix: identity comparison for None (`!= None` -> `is not None`).
            # TODO(review): logger hook is still a stub — wire save_log here.
            if (self.logger is not None):
                pass

        self._save_model(self.save_path)
        return np.array(epochs), np.array(losses)

    def _train_one_epoch(self, dataloader, model):
        """One pass over the dataloader; returns the per-batch loss values."""
        losses = []
        # Fix: dropped the unused enumerate index from the original loop.
        for x, y in dataloader:
            self.O.zero_grad()
            loss = self.L(model(x), y)
            loss.backward()
            self.O.step()
            print(loss.item())

            losses.append(loss.item())
        return losses

    def _validate(self, dataloader):
        # Validation is not implemented yet.
        pass

    def _save_model(self, path):
        """Serialize the model's state_dict to *path*."""
        T.save(self.M.state_dict(), path)
Loading
Loading