Skip to content

Commit

Permalink
feat: show both model ID and name for BioImage.IO models
Browse files Browse the repository at this point in the history
  • Loading branch information
qin-yu committed Dec 19, 2024
1 parent c27f68c commit 469aa6c
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 28 deletions.
25 changes: 14 additions & 11 deletions plantseg/core/zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,23 +457,26 @@ def _is_plantseg_model(self, collection_entry: dict) -> bool:
normalized_tags = ["".join(filter(str.isalnum, tag.lower())) for tag in tags]
return 'plantseg' in normalized_tags

def get_bioimageio_zoo_plantseg_model_names(self) -> list[str]:
"""Return a list of model names in the BioImage.IO Model Zoo tagged with 'plantseg'."""
def get_bioimageio_zoo_all_model_names(self) -> list[tuple[str, str]]:
"""Return a list of (model id, model display name) in the BioImage.IO Model Zoo."""
if not hasattr(self, 'models_bioimageio'):

Check warning on line 462 in plantseg/core/zoo.py

View check run for this annotation

Codecov / codecov/patch

plantseg/core/zoo.py#L462

Added line #L462 was not covered by tests
self.refresh_bioimageio_zoo_urls()
return sorted(model_zoo.models_bioimageio[model_zoo.models_bioimageio["supported"]].index.to_list())
id_name = self.models_bioimageio[['name_display']]
return sorted([(name, id) for id, name in id_name.itertuples()])

Check warning on line 465 in plantseg/core/zoo.py

View check run for this annotation

Codecov / codecov/patch

plantseg/core/zoo.py#L464-L465

Added lines #L464 - L465 were not covered by tests

def get_bioimageio_zoo_all_model_names(self) -> list[str]:
"""Return a list of all model names in the BioImage.IO Model Zoo."""
def get_bioimageio_zoo_plantseg_model_names(self) -> list[tuple[str, str]]:
"""Return a list of (model id, model display name) in the BioImage.IO Model Zoo tagged with 'plantseg'."""
if not hasattr(self, 'models_bioimageio'):
self.refresh_bioimageio_zoo_urls()
return sorted(model_zoo.models_bioimageio.index.to_list())
id_name = self.models_bioimageio[self.models_bioimageio["supported"]][['name_display']]
return sorted([(name, id) for id, name in id_name.itertuples()])

def get_bioimageio_zoo_other_model_names(self) -> list[str]:
"""Return a list of model names in the BioImage.IO Model Zoo not tagged with 'plantseg'."""
return sorted(
list(set(self.get_bioimageio_zoo_all_model_names()) - set(self.get_bioimageio_zoo_plantseg_model_names()))
)
def get_bioimageio_zoo_other_model_names(self) -> list[tuple[str, str]]:
"""Return a list of (model id, model display name) in the BioImage.IO Model Zoo not tagged with 'plantseg'."""
if not hasattr(self, 'models_bioimageio'):
self.refresh_bioimageio_zoo_urls()
id_name = self.models_bioimageio[~self.models_bioimageio["supported"]][['name_display']]
return sorted([(name, id) for id, name in id_name.itertuples()])

Check warning on line 479 in plantseg/core/zoo.py

View check run for this annotation

Codecov / codecov/patch

plantseg/core/zoo.py#L476-L479

Added lines #L476 - L479 were not covered by tests

def _flatten_module(self, module: Module) -> list[Module]:
"""Recursively flatten a PyTorch nn.Module into a list of its elemental layers."""
Expand Down
34 changes: 17 additions & 17 deletions plantseg/viewer_napari/widgets/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def to_choices(cls):
'label': 'BioImage.IO model',
'tooltip': 'Select a model from BioImage.IO model zoo.',
'choices': model_zoo.get_bioimageio_zoo_plantseg_model_names(),
'value': model_zoo.get_bioimageio_zoo_plantseg_model_names()[0][1],
},
advanced={
'label': 'Show advanced parameters',
Expand All @@ -131,22 +132,23 @@ def widget_unet_prediction(
mode: UNetPredictionMode = UNetPredictionMode.PLANTSEG,
plantseg_filter: bool = True,
model_name: Optional[str] = None,
model_id: Optional[str] = None,
model_id: Optional[str] = model_zoo.get_bioimageio_zoo_plantseg_model_names()[0][1],
device: str = ALL_DEVICES[0],
advanced: bool = False,
patch_size: tuple[int, int, int] = (128, 128, 128),
patch_halo: tuple[int, int, int] = (0, 0, 0),
single_patch: bool = False,
) -> None:
ps_image = PlantSegImage.from_napari_layer(image)

Check warning on line 142 in plantseg/viewer_napari/widgets/prediction.py

View check run for this annotation

Codecov / codecov/patch

plantseg/viewer_napari/widgets/prediction.py#L142

Added line #L142 was not covered by tests
widgets_to_update = [
widget_dt_ws.image,
widget_agglomeration.image,
widget_split_and_merge_from_scribbles.image,
]

if mode is UNetPredictionMode.PLANTSEG:
suffix = model_name
model_id = None
widgets_to_update = [

Check warning on line 147 in plantseg/viewer_napari/widgets/prediction.py

View check run for this annotation

Codecov / codecov/patch

plantseg/viewer_napari/widgets/prediction.py#L147

Added line #L147 was not covered by tests
widget_dt_ws.image,
widget_agglomeration.image,
widget_split_and_merge_from_scribbles.image,
]
return schedule_task(

Check warning on line 152 in plantseg/viewer_napari/widgets/prediction.py

View check run for this annotation

Codecov / codecov/patch

plantseg/viewer_napari/widgets/prediction.py#L152

Added line #L152 was not covered by tests
unet_prediction_task,
task_kwargs={
Expand All @@ -164,6 +166,10 @@ def widget_unet_prediction(
elif mode is UNetPredictionMode.BIOIMAGEIO:
suffix = model_id
model_name = None
widgets_to_update = [

Check warning on line 169 in plantseg/viewer_napari/widgets/prediction.py

View check run for this annotation

Codecov / codecov/patch

plantseg/viewer_napari/widgets/prediction.py#L169

Added line #L169 was not covered by tests
# BioImage.IO models may output multi-channel 3D image or even multi-channel scalar in CZYX format.
# So PlantSeg widgets, which all take ZYX or YX, are better not to be updated.
]
return schedule_task(

Check warning on line 173 in plantseg/viewer_napari/widgets/prediction.py

View check run for this annotation

Codecov / codecov/patch

plantseg/viewer_napari/widgets/prediction.py#L173

Added line #L173 was not covered by tests
biio_prediction_task,
task_kwargs={
Expand Down Expand Up @@ -210,17 +216,11 @@ def update_halo():
widget_unet_prediction.patch_size[0].enabled = True
widget_unet_prediction.patch_halo[0].enabled = True
elif widget_unet_prediction.mode.value is UNetPredictionMode.BIOIMAGEIO:
widget_unet_prediction.patch_halo.value = model_zoo.compute_3D_halo_for_bioimageio_models(
widget_unet_prediction.model_id.value
log(

Check warning on line 219 in plantseg/viewer_napari/widgets/prediction.py

View check run for this annotation

Codecov / codecov/patch

plantseg/viewer_napari/widgets/prediction.py#L219

Added line #L219 was not covered by tests
'Automatic halo not implemented for BioImage.IO models yet because they are handled by BioImage.IO Core.',
thread='BioImage.IO Core prediction',
level='info',
)
if model_zoo.is_2D_bioimageio_model(widget_unet_prediction.model_id.value):
widget_unet_prediction.patch_size[0].value = 0
widget_unet_prediction.patch_size[0].enabled = False
widget_unet_prediction.patch_halo[0].enabled = False
else:
widget_unet_prediction.patch_size[0].value = widget_unet_prediction.patch_size[1].value
widget_unet_prediction.patch_size[0].enabled = True
widget_unet_prediction.patch_halo[0].enabled = True
else:
raise NotImplementedError(f'Automatic halo not implemented for {widget_unet_prediction.mode.value} mode.')

Expand Down Expand Up @@ -270,7 +270,7 @@ def _on_widget_unet_prediction_plantseg_filter_change(plantseg_filter: bool):
else:
widget_unet_prediction.model_id.choices = (
model_zoo.get_bioimageio_zoo_plantseg_model_names()
+ [Separator]
+ [('', Separator)] # `[('', Separator)]` for list[tuple[str, str]], [Separator] for list[str]
+ model_zoo.get_bioimageio_zoo_other_model_names()
)

Expand Down

0 comments on commit 469aa6c

Please sign in to comment.