diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index 3f35e078..9ec9f558 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -2,6 +2,7 @@ from typing import Union from typing import List + from qtpy.QtCore import Qt from qtpy.QtCore import QUrl from qtpy.QtGui import QDesktopServices diff --git a/napari_cellseg3d/log_utility.py b/napari_cellseg3d/log_utility.py index c4288f1f..1ae9b2a0 100644 --- a/napari_cellseg3d/log_utility.py +++ b/napari_cellseg3d/log_utility.py @@ -1,5 +1,7 @@ import threading +import warnings +from qtpy import QtCore from qtpy.QtGui import QTextCursor from qtpy.QtWidgets import QTextEdit @@ -22,19 +24,55 @@ def __init__(self, parent): # def receive_log(self, text): # self.print_and_log(text) + def write(self, message): + self.lock.acquire() + try: + if not hasattr(self, "flag"): + self.flag = False + message = message.replace("\r", "").rstrip() + if message: + method = "replace_last_line" if self.flag else "append" + QtCore.QMetaObject.invokeMethod( + self, + method, + QtCore.Qt.QueuedConnection, + QtCore.Q_ARG(str, message), + ) + self.flag = True + else: + self.flag = False + + finally: + self.lock.release() - def print_and_log(self, text): + @QtCore.Slot(str) + def replace_last_line(self, text): + self.lock.acquire() + try: + cursor = self.textCursor() + cursor.movePosition(QTextCursor.End) + cursor.select(QTextCursor.BlockUnderCursor) + cursor.removeSelectedText() + cursor.insertBlock() + self.setTextCursor(cursor) + self.insertPlainText(text) + finally: + self.lock.release() + + def print_and_log(self, text, printing=True): """Utility used to both print to terminal and log text to a QTextEdit item in a thread-safe manner. Use only for important user info. Args: text (str): Text to be printed and logged + printing (bool): Whether to print the message as well or not using print(). Defaults to True. """ self.lock.acquire() try: - print(text) - # causes issue if you clik on terminal (tied to CMD QuickEdit mode) + if printing: + print(text) + # causes issue if you clik on terminal (tied to CMD QuickEdit mode on Windows) self.moveCursor(QTextCursor.End) self.insertPlainText(f"\n{text}") self.verticalScrollBar().setValue( @@ -42,3 +80,10 @@ def print_and_log(self, text): ) finally: self.lock.release() + + def warn(self, warning): + self.lock.acquire() + try: + warnings.warn(warning) + finally: + self.lock.release() diff --git a/napari_cellseg3d/model_framework.py b/napari_cellseg3d/model_framework.py index 7e1b4faf..1baf6eed 100644 --- a/napari_cellseg3d/model_framework.py +++ b/napari_cellseg3d/model_framework.py @@ -15,7 +15,7 @@ from napari_cellseg3d.models import model_SegResNet as SegResNet from napari_cellseg3d.models import model_TRAILMAP as TRAILMAP from napari_cellseg3d.models import model_VNet as VNet -from napari_cellseg3d.models import TRAILMAP_MS as TMAP +from napari_cellseg3d.models import model_TRAILMAP_MS as TRAILMAP_MS from napari_cellseg3d.plugin_base import BasePluginFolder warnings.formatwarning = utils.format_Warning @@ -62,8 +62,8 @@ def __init__(self, viewer: "napari.viewer.Viewer"): self.models_dict = { "VNet": VNet, "SegResNet": SegResNet, - "TRAILMAP pre-trained": TRAILMAP, - "TRAILMAP_MS": TMAP, + "TRAILMAP": TRAILMAP, + "TRAILMAP_MS": TRAILMAP_MS, } """dict: dictionary of available models, with string for widget display as key diff --git a/napari_cellseg3d/model_workers.py b/napari_cellseg3d/model_workers.py index 2cc6cb03..eb2f3cb9 100644 --- a/napari_cellseg3d/model_workers.py +++ b/napari_cellseg3d/model_workers.py @@ -1,9 +1,14 @@ import os import platform from pathlib import Path +import importlib.util +from typing import Optional +import warnings import numpy as np +from tifffile import imwrite import torch +from tqdm import tqdm # MONAI from monai.data import CacheDataset @@ -37,9 +42,10 @@ # Qt from qtpy.QtCore import Signal -from tifffile import imwrite + from napari_cellseg3d import utils +from napari_cellseg3d import log_utility # local from napari_cellseg3d.model_instance_seg import binary_connected @@ -57,17 +63,98 @@ # https://napari-staging-site.github.io/guides/stable/threading.html WEIGHTS_DIR = os.path.dirname(os.path.realpath(__file__)) + str( - Path("/models/saved_weights") + Path("/models/pretrained") ) +class WeightsDownloader: + """A utility class the downloads the weights of a model when needed.""" + + def __init__(self, log_widget: Optional[log_utility.Log] = None): + """ + Creates a WeightsDownloader, optionally with a log widget to display the progress. + + Args: + log_widget (log_utility.Log): a Log to display the progress bar in. If None, uses print() + """ + self.log_widget = log_widget + + def download_weights(self, model_name: str, model_weights_filename: str): + """ + Downloads a specific pretrained model. + This code is adapted from DeepLabCut with permission from MWMathis. + + Args: + model_name (str): name of the model to download + model_weights_filename (str): name of the .pth file expected for the model + """ + import json + import tarfile + import urllib.request + + def show_progress(count, block_size, total_size): + pbar.update(block_size) + + cellseg3d_path = os.path.split( + importlib.util.find_spec("napari_cellseg3d").origin + )[0] + pretrained_folder_path = os.path.join( + cellseg3d_path, "models", "pretrained" + ) + json_path = os.path.join( + pretrained_folder_path, "pretrained_model_urls.json" + ) + + check_path = os.path.join( + pretrained_folder_path, model_weights_filename + ) + if os.path.exists(check_path): + message = f"Weight file {model_weights_filename} already exists, skipping download step" + if self.log_widget is not None: + self.log_widget.print_and_log(message, printing=False) + print(message) + return + + with open(json_path) as f: + neturls = json.load(f) + if model_name in neturls.keys(): + url = neturls[model_name] + response = urllib.request.urlopen(url) + + start_message = f"Downloading the model from the M.W. Mathis Lab server {url}...." + total_size = int(response.getheader("Content-Length")) + if self.log_widget is None: + print(start_message) + pbar = tqdm(unit="B", total=total_size, position=0) + else: + self.log_widget.print_and_log(start_message) + pbar = tqdm( + unit="B", + total=total_size, + position=0, + file=self.log_widget, + ) + + filename, _ = urllib.request.urlretrieve( + url, reporthook=show_progress + ) + with tarfile.open(filename, mode="r:gz") as tar: + tar.extractall(pretrained_folder_path) + else: + raise ValueError( + f"Unknown model: {model_name}. Should be one of {', '.join(neturls)}" + ) + + class LogSignal(WorkerBaseSignals): """Signal to send messages to be logged from another thread. - Separate from Worker instances as indicated `here`_""" + Separate from Worker instances as indicated `here`_""" # TODO link ? log_signal = Signal(str) """qtpy.QtCore.Signal: signal to be sent when some text should be logged""" + warn_signal = Signal(str) + """qtpy.QtCore.Signal: signal to be sent when some warning should be emitted in main thread""" # Should not be an instance variable but a class variable, not defined in __init__, see # https://stackoverflow.com/questions/2970312/pyqt4-qtcore-pyqtsignal-object-has-no-attribute-connect @@ -128,6 +215,7 @@ def __init__( super().__init__(self.inference) self._signals = LogSignal() # add custom signals self.log_signal = self._signals.log_signal + self.warn_signal = self._signals.warn_signal ########################################### ########################################### self.device = device @@ -142,9 +230,11 @@ def __init__( self.window_infer_size = window_infer_size self.keep_on_cpu = keep_on_cpu self.stats_to_csv = stats_csv - """These attributes are all arguments of :py:func:~inference, please see that for reference""" + self.downloader = WeightsDownloader() + """Download utility""" + @staticmethod def create_inference_dict(images_filepaths): """Create a dict for MONAI with "image" keys with all image paths in :py:attr:`~self.images_filepaths` @@ -154,6 +244,9 @@ def create_inference_dict(images_filepaths): data_dicts = [{"image": image_name} for image_name in images_filepaths] return data_dicts + def set_download_log(self, widget): + self.downloader.log_widget = widget + def log(self, text): """Sends a signal that ``text`` should be logged @@ -162,6 +255,10 @@ def log(self, text): """ self.log_signal.emit(text) + def warn(self, warning): + """Sends a warning to main thread""" + self.warn_signal.emit(warning) + def log_parameters(self): self.log("-" * 20) @@ -233,8 +330,8 @@ def inference(self): """ sys = platform.system() print(f"OS is {sys}") - if sys == "Darwin": # required for macOS ? - torch.set_num_threads(1) + if sys == "Darwin": + torch.set_num_threads(1) # required for threading on macOS ? self.log("Number of threads has been set to 1 for macOS") images_dict = self.create_inference_dict(self.images_filepaths) @@ -260,7 +357,11 @@ def inference(self): model = self.model_dict["class"].get_net() if self.model_dict["name"] == "SegResNet": model = self.model_dict["class"].get_net()( - input_image_size=[dims, dims, dims], # TODO FIX ! + input_image_size=[ + dims, + dims, + dims, + ], # TODO FIX ! find a better way & remove model-specific code out_channels=1, # dropout_prob=0.3, ) @@ -304,12 +405,18 @@ def inference(self): # print(weights) self.log( "\nLoading weights..." - ) # TODO add try/except for invalid weights + ) # TODO add try/except for invalid weights for proper reset if self.weights_dict["custom"]: weights = self.weights_dict["path"] else: - weights = os.path.join(WEIGHTS_DIR, self.weights_dict["path"]) + self.downloader.download_weights( + self.model_dict["name"], + self.model_dict["class"].get_weights_file(), + ) + weights = os.path.join( + WEIGHTS_DIR, self.model_dict["class"].get_weights_file() + ) model.load_state_dict( torch.load( @@ -544,12 +651,13 @@ def __init__( """ - - print("init") super().__init__(self.train) self._signals = LogSignal() self.log_signal = self._signals.log_signal + self.warn_signal = self._signals.warn_signal + self._weight_error = False + ############################################# self.device = device self.model_dict = model_dict self.weights_path = weights_path @@ -571,7 +679,11 @@ def __init__( self.train_files = [] self.val_files = [] - print("end init") + ####################################### + self.downloader = WeightsDownloader() + + def set_download_log(self, widget): + self.downloader.log_widget = widget def log(self, text): """Sends a signal that ``text`` should be logged @@ -581,6 +693,10 @@ def log(self, text): """ self.log_signal.emit(text) + def warn(self, warning): + """Sends a warning to main thread""" + self.warn_signal.emit(warning) + def log_parameters(self): self.log("-" * 20) @@ -624,12 +740,19 @@ def log_parameters(self): if self.weights_path is not None: self.log(f"Using weights from : {self.weights_path}") + if self._weight_error: + self.log( + ">>>>>>>>>>>>>>>>>\n" + "WARNING:\nChosen weights were incompatible with the model,\n" + "the model will be trained from random weights\n" + "<<<<<<<<<<<<<<<<<\n" + ) # self.log("\n") self.log("-" * 20) def train(self): - """Trains the Pytorch model for the given number of epochs, with the selected model and data, + """Trains the PyTorch model for the given number of epochs, with the selected model and data, using the chosen batch size, validation interval, loss function, and number of samples. Will perform validation once every :py:obj:`val_interval` and save results if the mean dice is better @@ -838,17 +961,27 @@ def train(self): if self.weights_path is not None: if self.weights_path == "use_pretrained": weights_file = model_class.get_weights_file() + self.downloader.download_weights(model_name, weights_file) weights = os.path.join(WEIGHTS_DIR, weights_file) self.weights_path = weights else: weights = os.path.join(self.weights_path) - model.load_state_dict( - torch.load( - weights, - map_location=self.device, + try: + model.load_state_dict( + torch.load( + weights, + map_location=self.device, + ) ) - ) + except RuntimeError: + warn = ( + "WARNING:\nIt seems the weights were incompatible with the model,\n" + "the model will be trained from random weights" + ) + self.log(warn) + self.warn(warn) + self._weight_error = True if self.device.type == "cuda": self.log("\nUsing GPU :") diff --git a/napari_cellseg3d/models/TRAILMAP_MS.py b/napari_cellseg3d/models/TRAILMAP_MS.py deleted file mode 100644 index 9905c71a..00000000 --- a/napari_cellseg3d/models/TRAILMAP_MS.py +++ /dev/null @@ -1,126 +0,0 @@ -import os - -import torch -from torch import nn - -from napari_cellseg3d import utils - - -def get_weights_file(): - # model additionally trained on Mathis/Wyss mesoSPIM data - target_dir = utils.download_model("TRAILMAP_MS") - return os.path.join(target_dir, "TRAILMAP_MS_best_metric_epoch_26.pth") - - -def get_net(): - return TRAILMAP_MS(1, 1) - - -def get_output(model, input): - out = model(input) - - return out - - -def get_validation(model, val_inputs): - - return model(val_inputs) - - -class TRAILMAP_MS(nn.Module): - def __init__(self, in_ch, out_ch): - super().__init__() - self.conv0 = self.encoderBlock(in_ch, 32, 3) # input - self.conv1 = self.encoderBlock(32, 64, 3) # l1 - self.conv2 = self.encoderBlock(64, 128, 3) # l2 - self.conv3 = self.encoderBlock(128, 256, 3) # l3 - - self.bridge = self.bridgeBlock(256, 512, 3) - - self.up5 = self.decoderBlock(256 + 512, 256, 2) - - self.up6 = self.decoderBlock(128 + 256, 128, 2) - self.up7 = self.decoderBlock(128 + 64, 64, 2) # l2 - self.up8 = self.decoderBlock(64 + 32, 32, 2) # l1 - self.out = self.outBlock(32, out_ch, 1) - - def forward(self, x): - - conv0 = self.conv0(x) # l0 - conv1 = self.conv1(conv0) # l1 - conv2 = self.conv2(conv1) # l2 - conv3 = self.conv3(conv2) # l3 - - bridge = self.bridge(conv3) # bridge - # print("bridge :") - # print(bridge.shape) - - up5 = self.up5(torch.cat([conv3, bridge], 1)) # l3 - # print("up") - # print(up5.shape) - up6 = self.up6(torch.cat([up5, conv2], 1)) # l2 - # print(up6.shape) - up7 = self.up7(torch.cat([up6, conv1], 1)) # l1 - # print(up7.shape) - - up8 = self.up8(torch.cat([up7, conv0], 1)) # l1 - # print(up8.shape) - out = self.out(up8) - # print("out:") - # print(out.shape) - return out - - def encoderBlock(self, in_ch, out_ch, kernel_size, padding="same"): - - encode = nn.Sequential( - nn.Conv3d(in_ch, out_ch, kernel_size=kernel_size, padding=padding), - nn.BatchNorm3d(out_ch), - nn.ReLU(), - nn.Conv3d( - out_ch, out_ch, kernel_size=kernel_size, padding=padding - ), - nn.BatchNorm3d(out_ch), - nn.ReLU(), - nn.MaxPool3d(2), - ) - return encode - - def bridgeBlock(self, in_ch, out_ch, kernel_size, padding="same"): - - encode = nn.Sequential( - nn.Conv3d(in_ch, out_ch, kernel_size=kernel_size, padding=padding), - nn.BatchNorm3d(out_ch), - nn.ReLU(), - nn.Conv3d( - out_ch, out_ch, kernel_size=kernel_size, padding=padding - ), - nn.BatchNorm3d(out_ch), - nn.ReLU(), - # nn.ConvTranspose3d(out_ch, out_ch, kernel_size=2, stride=2), - ) - return encode - - def decoderBlock(self, in_ch, out_ch, kernel_size, padding="same"): - - decode = nn.Sequential( - nn.Conv3d(in_ch, out_ch, kernel_size=kernel_size, padding=padding), - nn.BatchNorm3d(out_ch), - nn.ReLU(), - nn.Conv3d( - out_ch, out_ch, kernel_size=kernel_size, padding=padding - ), - nn.BatchNorm3d(out_ch), - nn.ReLU(), - nn.ConvTranspose3d( - out_ch, out_ch, kernel_size=kernel_size, stride=(2, 2, 2) - ), - ) - return decode - - def outBlock(self, in_ch, out_ch, kernel_size, padding="same"): - - out = nn.Sequential( - nn.Conv3d(in_ch, out_ch, kernel_size=kernel_size, padding=padding), - # nn.BatchNorm3d(out_ch), - ) - return out diff --git a/napari_cellseg3d/models/model_SegResNet.py b/napari_cellseg3d/models/model_SegResNet.py index 98e07c49..41dc3bde 100644 --- a/napari_cellseg3d/models/model_SegResNet.py +++ b/napari_cellseg3d/models/model_SegResNet.py @@ -1,17 +1,12 @@ -import os - from monai.networks.nets import SegResNetVAE -from napari_cellseg3d import utils - def get_net(): return SegResNetVAE def get_weights_file(): - target_dir = utils.download_model("SegResNet") - return os.path.join(target_dir, "SegResNet.pth") + return "SegResNet.pth" def get_output(model, input): diff --git a/napari_cellseg3d/models/model_TRAILMAP.py b/napari_cellseg3d/models/model_TRAILMAP.py index 0c056032..ec4cfdbb 100644 --- a/napari_cellseg3d/models/model_TRAILMAP.py +++ b/napari_cellseg3d/models/model_TRAILMAP.py @@ -1,17 +1,15 @@ -import os - -from napari_cellseg3d import utils -from napari_cellseg3d.models.unet.model import UNet3D +import torch +from torch import nn def get_weights_file(): - # original model from Liqun Luo lab, transfered to pytorch - target_dir = utils.download_model("TRAILMAP") - return os.path.join(target_dir, "TRAILMAP_PyTorch.pth") + # model additionally trained on Mathis/Wyss mesoSPIM data + return "TRAILMAP.pth" + # FIXME currently incorrect, find good weights from TRAILMAP_test and upload them def get_net(): - return UNet3D(1, 1) + return TRAILMAP(1, 1) def get_output(model, input): @@ -23,3 +21,102 @@ def get_output(model, input): def get_validation(model, val_inputs): return model(val_inputs) + + +class TRAILMAP(nn.Module): + def __init__(self, in_ch, out_ch): + super().__init__() + self.conv0 = self.encoderBlock(in_ch, 32, 3) # input + self.conv1 = self.encoderBlock(32, 64, 3) # l1 + self.conv2 = self.encoderBlock(64, 128, 3) # l2 + self.conv3 = self.encoderBlock(128, 256, 3) # l3 + + self.bridge = self.bridgeBlock(256, 512, 3) + + self.up5 = self.decoderBlock(256 + 512, 256, 2) + + self.up6 = self.decoderBlock(128 + 256, 128, 2) + self.up7 = self.decoderBlock(128 + 64, 64, 2) # l2 + self.up8 = self.decoderBlock(64 + 32, 32, 2) # l1 + self.out = self.outBlock(32, out_ch, 1) + + def forward(self, x): + + conv0 = self.conv0(x) # l0 + conv1 = self.conv1(conv0) # l1 + conv2 = self.conv2(conv1) # l2 + conv3 = self.conv3(conv2) # l3 + + bridge = self.bridge(conv3) # bridge + # print("bridge :") + # print(bridge.shape) + + up5 = self.up5(torch.cat([conv3, bridge], 1)) # l3 + # print("up") + # print(up5.shape) + up6 = self.up6(torch.cat([up5, conv2], 1)) # l2 + # print(up6.shape) + up7 = self.up7(torch.cat([up6, conv1], 1)) # l1 + # print(up7.shape) + + up8 = self.up8(torch.cat([up7, conv0], 1)) # l1 + # print(up8.shape) + out = self.out(up8) + # print("out:") + # print(out.shape) + return out + + def encoderBlock(self, in_ch, out_ch, kernel_size, padding="same"): + + encode = nn.Sequential( + nn.Conv3d(in_ch, out_ch, kernel_size=kernel_size, padding=padding), + nn.BatchNorm3d(out_ch), + nn.ReLU(), + nn.Conv3d( + out_ch, out_ch, kernel_size=kernel_size, padding=padding + ), + nn.BatchNorm3d(out_ch), + nn.ReLU(), + nn.MaxPool3d(2), + ) + return encode + + def bridgeBlock(self, in_ch, out_ch, kernel_size, padding="same"): + + encode = nn.Sequential( + nn.Conv3d(in_ch, out_ch, kernel_size=kernel_size, padding=padding), + nn.BatchNorm3d(out_ch), + nn.ReLU(), + nn.Conv3d( + out_ch, out_ch, kernel_size=kernel_size, padding=padding + ), + nn.BatchNorm3d(out_ch), + nn.ReLU(), + # nn.ConvTranspose3d(out_ch, out_ch, kernel_size=2, stride=2), + ) + return encode + + def decoderBlock(self, in_ch, out_ch, kernel_size, padding="same"): + + decode = nn.Sequential( + nn.Conv3d(in_ch, out_ch, kernel_size=kernel_size, padding=padding), + nn.BatchNorm3d(out_ch), + nn.ReLU(), + nn.Conv3d( + out_ch, out_ch, kernel_size=kernel_size, padding=padding + ), + nn.BatchNorm3d(out_ch), + nn.ReLU(), + nn.ConvTranspose3d( + out_ch, out_ch, kernel_size=kernel_size, stride=(2, 2, 2) + ), + ) + return decode + + def outBlock(self, in_ch, out_ch, kernel_size, padding="same"): + + out = nn.Sequential( + nn.Conv3d(in_ch, out_ch, kernel_size=kernel_size, padding=padding), + # nn.BatchNorm3d(out_ch), + ) + return out diff --git a/napari_cellseg3d/models/model_TRAILMAP_MS.py b/napari_cellseg3d/models/model_TRAILMAP_MS.py new file mode 100644 index 00000000..1ee50158 --- /dev/null +++ b/napari_cellseg3d/models/model_TRAILMAP_MS.py @@ -0,0 +1,21 @@ +from napari_cellseg3d.models.unet.model import UNet3D + + +def get_weights_file(): + # original model from Liqun Luo lab, transferred to pytorch and trained on mesoSPIM-acquired data (mostly cFOS as of July 2022) + return "TRAILMAP_MS_best_metric_epoch_26.pth" + + +def get_net(): + return UNet3D(1, 1) + + +def get_output(model, input): + out = model(input) + + return out + + +def get_validation(model, val_inputs): + + return model(val_inputs) diff --git a/napari_cellseg3d/models/model_VNet.py b/napari_cellseg3d/models/model_VNet.py index 2b3d758b..0c5f0b75 100644 --- a/napari_cellseg3d/models/model_VNet.py +++ b/napari_cellseg3d/models/model_VNet.py @@ -1,18 +1,13 @@ -import os - from monai.inferers import sliding_window_inference from monai.networks.nets import VNet -from napari_cellseg3d import utils - def get_net(): return VNet() def get_weights_file(): - target_dir = utils.download_model("VNet") - return os.path.join(target_dir, "VNet_40e.pth") + return "VNet_40e.pth" def get_output(model, input): diff --git a/napari_cellseg3d/plugin_model_inference.py b/napari_cellseg3d/plugin_model_inference.py index 33f7ca39..711f4b49 100644 --- a/napari_cellseg3d/plugin_model_inference.py +++ b/napari_cellseg3d/plugin_model_inference.py @@ -534,7 +534,6 @@ def start(self): else: weights_dict = { "custom": False, - "path": self.get_model(model_key).get_weights_file(), } if self.anisotropy_wdgt.is_enabled(): @@ -591,6 +590,7 @@ def start(self): keep_on_cpu=self.keep_on_cpu, stats_csv=self.stats_to_csv, ) + self.worker.set_download_log(self.log) yield_connect_show_res = lambda data: self.on_yield( data, @@ -599,6 +599,7 @@ def start(self): self.worker.started.connect(self.on_start) self.worker.log_signal.connect(self.log.print_and_log) + self.worker.warn_signal.connect(self.log.warn) self.worker.yielded.connect(yield_connect_show_res) self.worker.errored.connect( yield_connect_show_res diff --git a/napari_cellseg3d/plugin_model_training.py b/napari_cellseg3d/plugin_model_training.py index d8090182..517cf8fc 100644 --- a/napari_cellseg3d/plugin_model_training.py +++ b/napari_cellseg3d/plugin_model_training.py @@ -35,7 +35,7 @@ class Trainer(ModelFramework): - """A plugin to train pre-defined Pytorch models for one-channel segmentation directly in napari. + """A plugin to train pre-defined PyTorch models for one-channel segmentation directly in napari. Features parameter selection for training, dynamic loss plotting and automatic saving of the best weights during training through validation.""" @@ -852,10 +852,12 @@ def start(self): do_augmentation=self.augment_choice.isChecked(), deterministic=seed_dict, ) + self.worker.set_download_log(self.log) [btn.setVisible(False) for btn in self.close_buttons] self.worker.log_signal.connect(self.log.print_and_log) + self.worker.warn_signal.connect(self.log.warn) self.worker.started.connect(self.on_start) @@ -989,6 +991,11 @@ def on_yield(data, widget): def make_csv(self): size_column = range(1, self.max_epochs + 1) + + if len(self.loss_values) == 0 or self.loss_values is None: + warnings.warn("No loss values to add to csv !") + return + self.df = pd.DataFrame( { "epoch": size_column, diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py index fd28d13d..bc725fc1 100644 --- a/napari_cellseg3d/utils.py +++ b/napari_cellseg3d/utils.py @@ -1,9 +1,9 @@ -import importlib.util import os import warnings from datetime import datetime from pathlib import Path + import cv2 import numpy as np from dask_image.imread import imread as dask_imread @@ -978,44 +978,3 @@ def merge_imgs(imgs, original_image_shape): print(merged_imgs.shape) return merged_imgs - - -def download_model(modelname): - """ - Downloads a specific pretained model. - This code is adapted from DeepLabCut with permission from MWMathis - """ - import json - import tarfile - import urllib.request - - def show_progress(count, block_size, total_size): - pbar.update(block_size) - - cellseg3d_path = os.path.split( - importlib.util.find_spec("napari_cellseg3d").origin - )[0] - pretrained_folder_path = os.path.join( - cellseg3d_path, "models", "pretrained" - ) - json_path = os.path.join( - pretrained_folder_path, "pretrained_model_urls.json" - ) - with open(json_path) as f: - neturls = json.load(f) - if modelname in neturls.keys(): - url = neturls[modelname] - response = urllib.request.urlopen(url) - print( - f"Downloading the model from the M.W. Mathis Lab server {url}...." - ) - total_size = int(response.getheader("Content-Length")) - pbar = tqdm(unit="B", total=total_size, position=0) - filename, _ = urllib.request.urlretrieve(url, reporthook=show_progress) - with tarfile.open(filename, mode="r:gz") as tar: - tar.extractall(pretrained_folder_path) - return pretrained_folder_path - else: - raise ValueError( - f"Unknown model. `modelname` should be one of {', '.join(neturls)}" - )