Skip to content

Commit

Permalink
Improve LayerSelecter + Tooltips update (#64)
Browse files Browse the repository at this point in the history
* Improved layer selecter code

- Better handling of removed/added layers independent of widget opening time
- Still missing layer renaming events handling

* Working layer rename

* Tooltips update

* Try to fix tests on Win

* Disable tests on Win for now

* Disable wandb for tests

* Small WandB improvement

* Change logger back to debug

* Fix layer addition and WandB error

* Fix bug with missing variable in some cases

* Fix WandB project name for WNet

* Fix make_csv
  • Loading branch information
C-Achard authored Mar 4, 2024
1 parent 4fa481b commit 570666a
Show file tree
Hide file tree
Showing 7 changed files with 117 additions and 13 deletions.
2 changes: 1 addition & 1 deletion napari_cellseg3d/_tests/test_plugin_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def test_utils_plugin(make_napari_viewer_proxy):
view = make_napari_viewer_proxy()
widget = Utilities(view)

image = rand_gen.random((10, 10, 10)).astype(np.uint8)
image = rand_gen.random((10, 10, 10)) # .astype(np.uint8)
image_layer = view.add_image(image, name="image")
label_layer = view.add_labels(image.astype(np.uint8), name="labels")

Expand Down
2 changes: 2 additions & 0 deletions napari_cellseg3d/_tests/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
)
from napari_cellseg3d.config import MODEL_LIST

WANDB_MODE = "disabled"

im_path = Path(__file__).resolve().parent / "res/test.tif"
im_path_str = str(im_path)
lab_path = Path(__file__).resolve().parent / "res/test_labels.tif"
Expand Down
15 changes: 15 additions & 0 deletions napari_cellseg3d/code_plugins/plugin_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,12 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None):
self.container = self._build()

self.function = clear_small_objects
self._set_tooltips()

def _set_tooltips(self):
self.size_for_removal_counter.setToolTip(
"Size of the objects to remove, in pixels."
)

def _build(self):
container = ui.ContainerWidget()
Expand Down Expand Up @@ -647,6 +653,15 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None):
self.container = self._build()
self.function = threshold

self._set_tooltips()

def _set_tooltips(self):
self.binarize_counter.setToolTip(
"Value to use as threshold for binarization."
"For labels, use the highest ID you want to keep. All lower IDs will be removed."
"For images, use the intensity value (pixel value) to threshold the image."
)

def _build(self):
container = ui.ContainerWidget()

Expand Down
4 changes: 4 additions & 0 deletions napari_cellseg3d/code_plugins/plugin_model_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,7 @@ def _set_tooltips(self):
)

thresh_desc = (
"NOT RECOMMENDED ON FIRST RUN - check results without first!\n"
"Thresholding : all values in the image below the chosen probability"
" threshold will be set to 0, and all others to 1."
)
Expand All @@ -301,6 +302,7 @@ def _set_tooltips(self):
"If enabled, data will be kept on the RAM rather than the VRAM.\nCan avoid out of memory issues with CUDA"
)
self.use_instance_choice.setToolTip(
"NOT RECOMMENDED ON FIRST RUN - check results without first!\n"
"Instance segmentation will convert instance (0/1) labels to labels"
" that attempt to assign an unique ID to each cell."
)
Expand Down Expand Up @@ -653,6 +655,8 @@ def _display_results(self, result: InferenceResult):
if result.semantic_segmentation[channel].sum() > 0:
index_channel_least_labelled = channel
break
# if no channel has any label, use the first one
index_channel_least_labelled = 0
viewer.dims.set_point(
0, index_channel_least_labelled
) # TODO(cyril: check if this is always the right axis
Expand Down
10 changes: 5 additions & 5 deletions napari_cellseg3d/code_plugins/plugin_model_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -1247,7 +1247,7 @@ def on_finish(self):
self.log.print_and_log("*" * 10)
try:
self._make_csv()
except ValueError as e:
except (ValueError, KeyError) as e:
logger.warning(f"Error while saving CSV report: {e}")

self.start_btn.setText("Start")
Expand Down Expand Up @@ -1375,11 +1375,11 @@ def _make_csv(self):
try:
self.loss_1_values["Loss"]
supervised = True
except KeyError("Loss"):
except KeyError:
try:
self.loss_1_values["SoftNCuts"]
supervised = False
except KeyError("SoftNCuts") as e:
except KeyError as e:
raise KeyError(
"Error when making csv. Check loss dict keys ?"
) from e
Expand All @@ -1398,8 +1398,8 @@ def _make_csv(self):
"validation": val,
}
)
if len(val) != len(self.loss_1_values):
err = f"Validation and loss values don't have the same length ! Got {len(val)} and {len(self.loss_1_values)}"
if len(val) != len(self.loss_1_values["Loss"]):
err = f"Validation and loss values don't have the same length ! Got {len(val)} and {len(self.loss_1_values['Loss'])}"
logger.error(err)
raise ValueError(err)
else:
Expand Down
96 changes: 89 additions & 7 deletions napari_cellseg3d/interface.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"""User interface functions and aliases."""
import contextlib
import threading
from functools import partial
from typing import List, Optional
from warnings import warn

import napari

Expand Down Expand Up @@ -804,9 +806,12 @@ def __init__(
self.layer_description.setVisible(False)
# self.layer_list.setSizeAdjustPolicy(QComboBox.AdjustToContents) # use tooltip instead ?

# connect to LayerList events
self._viewer.layers.events.inserted.connect(partial(self._add_layer))
self._viewer.layers.events.removed.connect(partial(self._remove_layer))
self._viewer.layers.events.changed.connect(self._check_for_layers)

# update self.layer_list when layers are added or removed
self.layer_list.currentIndexChanged.connect(self._update_tooltip)
self.layer_list.currentTextChanged.connect(self._update_description)

Expand All @@ -816,20 +821,81 @@ def __init__(
)
self._check_for_layers()

def _get_all_layers(self):
return [
self.layer_list.itemText(i) for i in range(self.layer_list.count())
]

def _check_for_layers(self):
"""Check for layers of the correct type and update the dropdown menu.
Also removes layers that have been removed from the viewer.
"""
for layer in self._viewer.layers:
if isinstance(layer, self.layer_type):
layer.events.name.connect(self._rename_layer)

if (
isinstance(layer, self.layer_type)
and layer.name not in self._get_all_layers()
):
logger.debug(
f"Layer {layer.name} - List : {self._get_all_layers()}"
)
# add new layers of correct type
self.layer_list.addItem(layer.name)
logger.debug(f"Layer {layer.name} has been added to the menu")
# break
# once added, check again for previously renamed layers
self._check_for_removed_layer(layer)

if layer.name in self._get_all_layers() and not isinstance(
layer, self.layer_type
):
# remove layers of incorrect type
index = self.layer_list.findText(layer.name)
self.layer_list.removeItem(index)
logger.debug(
f"Layer {layer.name} has been removed from the menu"
)

self._check_for_removed_layers()
self._update_tooltip()
self._update_description()

def _check_for_removed_layer(self, layer):
"""Check if a specific layer has been removed from the viewer and must be removed from the menu."""
if isinstance(layer, str):
name = layer
elif isinstance(layer, self.layer_type):
name = layer.name
else:
logger.warning("Layer is not a string or a valid napari layer")
return

if name in self._get_all_layers() and name not in [
l.name for l in self._viewer.layers
]:
index = self.layer_list.findText(name)
self.layer_list.removeItem(index)
logger.debug(f"Layer {name} has been removed from the menu")

def _check_for_removed_layers(self):
"""Check for layers that have been removed from the viewer and must be removed from the menu."""
for layer in self._get_all_layers():
self._check_for_removed_layer(layer)

def _update_tooltip(self):
self.layer_list.setToolTip(self.layer_list.currentText())

def _update_description(self):
try:
if self.layer_list.currentText() != "":
self.layer_description.setVisible(True)
shape_desc = f"Shape : {self.layer_data().shape}"
self.layer_description.setText(shape_desc)
try:
shape_desc = f"Shape : {self.layer_data().shape}"
self.layer_description.setText(shape_desc)
self.layer_description.setVisible(True)
except AttributeError:
self.layer_description.setVisible(False)
else:
self.layer_description.setVisible(False)
except KeyError:
Expand All @@ -841,6 +907,13 @@ def _add_layer(self, event):
if isinstance(inserted_layer, self.layer_type):
self.layer_list.addItem(inserted_layer.name)

# check for renaming
inserted_layer.events.name.connect(self._rename_layer)

def _rename_layer(self, _):
# on layer rename, check for removed/new layers
self._check_for_layers()

def _remove_layer(self, event):
removed_layer = event.value

Expand All @@ -867,15 +940,24 @@ def layer(self):

def layer_name(self):
"""Returns the name of the layer selected in the dropdown menu."""
return self.layer_list.currentText()
try:
return self.layer_list.currentText()
except (KeyError, ValueError):
logger.warning("Layer list is empty")
return None

def layer_data(self):
"""Returns the data of the layer selected in the dropdown menu."""
if self.layer_list.count() < 1:
logger.debug("Layer list is empty")
return None

return self.layer().data
try:
return self.layer().data
except (KeyError, ValueError):
msg = f"Layer {self.layer_name()} has no data. Layer might have been renamed or removed."
logger.warning(msg)
warn(msg, stacklevel=1)
return None


class FilePathWidget(QWidget): # TODO include load as folder
Expand Down
1 change: 1 addition & 0 deletions napari_cellseg3d/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
LOGGER = logging.getLogger(__name__)
###############
# Global logging level setting
# SET TO INFO FOR RELEASE
# LOGGER.setLevel(logging.DEBUG)
LOGGER.setLevel(logging.INFO)
###############
Expand Down

0 comments on commit 570666a

Please sign in to comment.