Skip to content

Commit

Permalink
Added title info to visualization with optional show_id arg (#211)
Browse files Browse the repository at this point in the history
* Added title info to visualization with optional show_id arg

* Fixed tests

* Fixed type annotations

* Added docstring

* Update src/cleanvision/imagelab.py

Co-authored-by: Jonas Mueller <[email protected]>

---------

Co-authored-by: Jonas Mueller <[email protected]>
  • Loading branch information
sanjanag and jwmueller authored Jul 17, 2023
1 parent 28c15f0 commit ff59d69
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 72 deletions.
50 changes: 29 additions & 21 deletions src/cleanvision/imagelab.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,7 @@ def report(
num_images: Optional[int] = None,
verbosity: int = 1,
print_summary: bool = True,
show_id: bool = False,
) -> None:
"""Prints summary of the issues found in your dataset.
By default, this method depicts the images representing top-most severe instances of each issue type.
Expand All @@ -391,6 +392,9 @@ def report(
print_summary : bool, default=True
If True, prints the summary of issues found in the dataset.
show_id: bool, default=False
If True, prints the dataset ID of each image shown in the report.
Examples
--------
Default usage
Expand Down Expand Up @@ -446,6 +450,7 @@ def report(
issue_type,
report_args["num_images"],
report_args["cell_size"],
show_id,
)
else:
print(
Expand All @@ -471,6 +476,7 @@ def _visualize(
issue_type: str,
num_images: int,
cell_size: Tuple[int, int],
show_id: bool,
) -> None:
# todo: remove dependency on issue manager
issue_manager = self._get_issue_manager(issue_type)
Expand All @@ -484,22 +490,21 @@ def _visualize(
indices = scores.index.tolist()
images = [self._dataset[i] for i in indices]

titles = [f"score : {x:.4f}" for x in scores]

# Add size information for odd sized images
additional_info = None
# construct title info
title_info = {"scores": [f"score : {x:.4f}" for x in scores]}
if show_id:
title_info["ids"] = [f"id : {i}" for i in indices]
if issue_type == IssueType.ODD_SIZE.value:
additional_info = []
for image in images:
additional_info.append(f"original size: {image.size}")
title_info["size"] = [
f"original size: {image.size}" for image in images
]

if images:
VizManager.individual_images(
images=images,
titles=titles,
title_info=title_info,
ncols=self._config["visualize_num_images_per_row"],
cell_size=cell_size,
additional_info=additional_info,
)

elif viz_name == "image_sets":
Expand All @@ -511,15 +516,15 @@ def _visualize(
for indices in image_sets_indices:
image_sets.append([self._dataset[index] for index in indices])

title_sets = [
[self._dataset.get_name(index) for index in s]
for s in image_sets_indices
]
title_info_sets = []
for s in image_sets_indices:
title_info = {"name": [self._dataset.get_name(index) for index in s]}
title_info_sets.append(title_info)

if image_sets:
VizManager.image_sets(
image_sets,
title_sets,
title_info_sets,
ncols=self._config["visualize_num_images_per_row"],
cell_size=cell_size,
)
Expand All @@ -532,6 +537,7 @@ def visualize(
issue_types: Optional[List[str]] = None,
num_images: int = 4,
cell_size: Tuple[int, int] = (2, 2),
show_id: bool = False,
) -> None:
"""Show specific images.
Expand Down Expand Up @@ -599,24 +605,24 @@ def visualize(
if len(issue_types) == 0:
raise ValueError("issue_types list is empty")
for issue_type in issue_types:
self._visualize(issue_type, num_images, cell_size)
self._visualize(issue_type, num_images, cell_size, show_id)
elif image_files is not None:
if len(image_files) == 0:
raise ValueError("image_files list is empty.")
images = [Image.open(path) for path in image_files]
titles = [path.split("/")[-1] for path in image_files]
title_info = {"path": [path.split("/")[-1] for path in image_files]}
VizManager.individual_images(
images,
titles,
title_info,
ncols=self._config["visualize_num_images_per_row"],
cell_size=cell_size,
)
elif indices:
images = [self._dataset[i] for i in indices]
titles = [self._dataset.get_name(i) for i in indices]
title_info = {"name": [self._dataset.get_name(i) for i in indices]}
VizManager.individual_images(
images,
titles,
title_info,
ncols=self._config["visualize_num_images_per_row"],
cell_size=cell_size,
)
Expand All @@ -628,10 +634,12 @@ def visualize(
self._dataset.index, min(num_images, len(self._dataset))
)
images = [self._dataset[i] for i in image_indices]
titles = [self._dataset.get_name(i) for i in image_indices]
title_info = {
"name": [self._dataset.get_name(i) for i in image_indices]
}
VizManager.individual_images(
images,
titles,
title_info,
ncols=self._config["visualize_num_images_per_row"],
cell_size=cell_size,
)
Expand Down
96 changes: 55 additions & 41 deletions src/cleanvision/utils/viz_manager.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Tuple, Optional
from typing import List, Tuple, Dict

import math
import matplotlib.axes
Expand All @@ -10,24 +10,23 @@ class VizManager:
@staticmethod
def individual_images(
images: List[Image.Image],
titles: List[str],
title_info: Dict[str, List[str]],
ncols: int,
cell_size: Tuple[int, int],
additional_info: Optional[List[str]] = None,
) -> None:
"""Plots a list of images in a grid."""
plot_image_grid(images, titles, ncols, cell_size, additional_info)
plot_image_grid(images, title_info, ncols, cell_size)

@staticmethod
def image_sets(
image_sets: List[List[Image.Image]],
title_sets: List[List[str]],
title_info_sets: List[Dict[str, List[str]]],
ncols: int,
cell_size: Tuple[int, int],
) -> None:
for i, s in enumerate(image_sets):
print(f"Set: {i}")
plot_image_grid(s, title_sets[i], ncols, cell_size)
plot_image_grid(s, title_info_sets[i], ncols, cell_size)


def set_image_on_axes(image: Image.Image, ax: matplotlib.axes.Axes, title: str) -> None:
Expand All @@ -38,51 +37,66 @@ def set_image_on_axes(image: Image.Image, ax: matplotlib.axes.Axes, title: str)
ax.imshow(image, cmap=cmap, vmin=0, vmax=255)


def truncate_titles(cell_width: int, titles: List[str]) -> List[str]:
"""Converts font size of 7 into inches"""
CHARACTER_SIZE_INCHES = 7 * (1 / 72)

chars_allowed = math.ceil(cell_width / CHARACTER_SIZE_INCHES) - 4

k1 = 1
while k1 <= chars_allowed and titles[0][:k1] == titles[1][:k1]:
k1 += 1
k2 = 1
while (
k2 <= chars_allowed
and titles[0][(len(titles[0]) - k2) :] == titles[1][(len(titles[1]) - k2) :]
):
k2 += 1

if k1 > k2:
truncate_from_front = True
else:
truncate_from_front = False

for i in range(len(titles)):
title_width = len(titles[i]) * CHARACTER_SIZE_INCHES
if title_width >= cell_width:
titles[i] = (
("..." + titles[i][len(titles[i]) - chars_allowed :])
if truncate_from_front
else (titles[i][:chars_allowed] + "...")
)
return titles


def construct_titles(title_info: Dict[str, List[str]], cell_width: int) -> List[str]:
keys = list(title_info.keys())
nimages = len(title_info[keys[0]])

# truncate longer lines
if nimages > 1:
for key in keys:
title_info[key] = truncate_titles(cell_width, title_info[key])

# join all keys
titles = []
for i in range(nimages):
titles.append("\n".join(title_info[key][i] for key in keys))
return titles


def plot_image_grid(
images: List[Image.Image],
titles: List[str],
title_info: Dict[str, List[str]],
ncols: int,
cell_size: Tuple[int, int],
additional_info: Optional[List[str]] = None,
) -> None:
nrows = math.ceil(len(images) / ncols)
ncols = min(ncols, len(images))
fig, axes = plt.subplots(
nrows, ncols, figsize=(cell_size[0] * ncols, cell_size[1] * nrows)
)

"""Converts font size of 7 into inches"""
CHARACTER_SIZE_INCHES = 7 * (1 / 72)

chars_allowed = math.ceil(cell_size[0] / CHARACTER_SIZE_INCHES) - 4

if len(images) > 1:
k1 = 1
while k1 <= chars_allowed and titles[0][:k1] == titles[1][:k1]:
k1 += 1
k2 = 1
while (
k2 <= chars_allowed
and titles[0][(len(titles[0]) - k2) :] == titles[1][(len(titles[1]) - k2) :]
):
k2 += 1

if k1 > k2:
truncate_from_front = True
else:
truncate_from_front = False

for i in range(len(images)):
title_width = len(titles[i]) * CHARACTER_SIZE_INCHES
if title_width >= cell_size[0]:
titles[i] = (
("..." + titles[i][len(titles[i]) - chars_allowed :])
if truncate_from_front
else (titles[i][:chars_allowed] + "...")
)
if additional_info is not None:
for i in range(len(images)):
titles[i] = f"{titles[i]}\n{additional_info[i]}"
titles = construct_titles(title_info, cell_size[0])
if nrows > 1:
idx = 0
for i in range(nrows):
Expand Down
20 changes: 10 additions & 10 deletions tests/test_viz_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,26 +7,26 @@
class TestVizManager:
@pytest.mark.usefixtures("set_plt_show")
@pytest.mark.parametrize(
("images", "titles"),
("images", "title_info"),
[
([Image.new("L", (100, 100))], ["image_title"]),
([Image.new("L", (100, 100))] * 2, ["image_title"] * 4),
([Image.new("L", (100, 100))] * 6, ["imaxge_title"] * 6),
([Image.new("L", (100, 100))], {"name": ["image_title"]}),
([Image.new("L", (100, 100))] * 2, {"name": ["image_title"] * 4}),
([Image.new("L", (100, 100))] * 6, {"name": ["imaxge_title"] * 6}),
],
ids=["plot single image", "plot <=4 images", "plt > 4 images"],
)
def test_individual_images(self, images, titles):
VizManager.individual_images(images, titles, 4, (2, 2))
def test_individual_images(self, images, title_info):
VizManager.individual_images(images, title_info, 4, (2, 2))

@pytest.mark.usefixtures("set_plt_show")
@pytest.mark.parametrize(
("image_sets", "title_sets"),
("image_sets", "title_info_sets"),
[
(
[[Image.new("L", (100, 100))], [Image.new("L", (100, 100))] * 2],
[["image_title"], ["image_title"] * 2],
[{"name": ["image_title"]}, {"name": ["image_title"] * 2}],
),
],
)
def test_image_sets(self, image_sets, title_sets):
VizManager.image_sets(image_sets, title_sets, 4, (2, 2))
def test_image_sets(self, image_sets, title_info_sets):
VizManager.image_sets(image_sets, title_info_sets, 4, (2, 2))

0 comments on commit ff59d69

Please sign in to comment.