Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 29 additions & 21 deletions src/cleanvision/imagelab.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,7 @@ def report(
num_images: Optional[int] = None,
verbosity: int = 1,
print_summary: bool = True,
show_id: bool = False,
) -> None:
"""Prints summary of the issues found in your dataset.
By default, this method depicts the images representing top-most severe instances of each issue type.
Expand All @@ -391,6 +392,9 @@ def report(
print_summary : bool, default=True
If True, prints the summary of issues found in the dataset.

show_id: bool, default=False
If True, prints the dataset ID of each image shown in the report.

Examples
--------
Default usage
Expand Down Expand Up @@ -446,6 +450,7 @@ def report(
issue_type,
report_args["num_images"],
report_args["cell_size"],
show_id,
)
else:
print(
Expand All @@ -471,6 +476,7 @@ def _visualize(
issue_type: str,
num_images: int,
cell_size: Tuple[int, int],
show_id: bool,
) -> None:
# todo: remove dependency on issue manager
issue_manager = self._get_issue_manager(issue_type)
Expand All @@ -484,22 +490,21 @@ def _visualize(
indices = scores.index.tolist()
images = [self._dataset[i] for i in indices]

titles = [f"score : {x:.4f}" for x in scores]

# Add size information for odd sized images
additional_info = None
# construct title info
title_info = {"scores": [f"score : {x:.4f}" for x in scores]}
if show_id:
title_info["ids"] = [f"id : {i}" for i in indices]
if issue_type == IssueType.ODD_SIZE.value:
additional_info = []
for image in images:
additional_info.append(f"original size: {image.size}")
title_info["size"] = [
f"original size: {image.size}" for image in images
]

if images:
VizManager.individual_images(
images=images,
titles=titles,
title_info=title_info,
ncols=self._config["visualize_num_images_per_row"],
cell_size=cell_size,
additional_info=additional_info,
)

elif viz_name == "image_sets":
Expand All @@ -511,15 +516,15 @@ def _visualize(
for indices in image_sets_indices:
image_sets.append([self._dataset[index] for index in indices])

title_sets = [
[self._dataset.get_name(index) for index in s]
for s in image_sets_indices
]
title_info_sets = []
for s in image_sets_indices:
title_info = {"name": [self._dataset.get_name(index) for index in s]}
title_info_sets.append(title_info)

if image_sets:
VizManager.image_sets(
image_sets,
title_sets,
title_info_sets,
ncols=self._config["visualize_num_images_per_row"],
cell_size=cell_size,
)
Expand All @@ -532,6 +537,7 @@ def visualize(
issue_types: Optional[List[str]] = None,
num_images: int = 4,
cell_size: Tuple[int, int] = (2, 2),
show_id: bool = False,
) -> None:
"""Show specific images.

Expand Down Expand Up @@ -599,24 +605,24 @@ def visualize(
if len(issue_types) == 0:
raise ValueError("issue_types list is empty")
for issue_type in issue_types:
self._visualize(issue_type, num_images, cell_size)
self._visualize(issue_type, num_images, cell_size, show_id)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

try to hit this line in unit tests if not too hard

elif image_files is not None:
if len(image_files) == 0:
raise ValueError("image_files list is empty.")
images = [Image.open(path) for path in image_files]
titles = [path.split("/")[-1] for path in image_files]
title_info = {"path": [path.split("/")[-1] for path in image_files]}
VizManager.individual_images(
images,
titles,
title_info,
ncols=self._config["visualize_num_images_per_row"],
cell_size=cell_size,
)
elif indices:
images = [self._dataset[i] for i in indices]
titles = [self._dataset.get_name(i) for i in indices]
title_info = {"name": [self._dataset.get_name(i) for i in indices]}
VizManager.individual_images(
images,
titles,
title_info,
ncols=self._config["visualize_num_images_per_row"],
cell_size=cell_size,
)
Expand All @@ -628,10 +634,12 @@ def visualize(
self._dataset.index, min(num_images, len(self._dataset))
)
images = [self._dataset[i] for i in image_indices]
titles = [self._dataset.get_name(i) for i in image_indices]
title_info = {
"name": [self._dataset.get_name(i) for i in image_indices]
}
VizManager.individual_images(
images,
titles,
title_info,
ncols=self._config["visualize_num_images_per_row"],
cell_size=cell_size,
)
Expand Down
96 changes: 55 additions & 41 deletions src/cleanvision/utils/viz_manager.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Tuple, Optional
from typing import List, Tuple, Dict

import math
import matplotlib.axes
Expand All @@ -10,24 +10,23 @@ class VizManager:
@staticmethod
def individual_images(
images: List[Image.Image],
titles: List[str],
title_info: Dict[str, List[str]],
ncols: int,
cell_size: Tuple[int, int],
additional_info: Optional[List[str]] = None,
) -> None:
"""Plots a list of images in a grid."""
plot_image_grid(images, titles, ncols, cell_size, additional_info)
plot_image_grid(images, title_info, ncols, cell_size)

@staticmethod
def image_sets(
image_sets: List[List[Image.Image]],
title_sets: List[List[str]],
title_info_sets: List[Dict[str, List[str]]],
ncols: int,
cell_size: Tuple[int, int],
) -> None:
for i, s in enumerate(image_sets):
print(f"Set: {i}")
plot_image_grid(s, title_sets[i], ncols, cell_size)
plot_image_grid(s, title_info_sets[i], ncols, cell_size)


def set_image_on_axes(image: Image.Image, ax: matplotlib.axes.Axes, title: str) -> None:
Expand All @@ -38,51 +37,66 @@ def set_image_on_axes(image: Image.Image, ax: matplotlib.axes.Axes, title: str)
ax.imshow(image, cmap=cmap, vmin=0, vmax=255)


def truncate_titles(cell_width: int, titles: List[str]) -> List[str]:
"""Converts font size of 7 into inches"""
CHARACTER_SIZE_INCHES = 7 * (1 / 72)

chars_allowed = math.ceil(cell_width / CHARACTER_SIZE_INCHES) - 4

k1 = 1
while k1 <= chars_allowed and titles[0][:k1] == titles[1][:k1]:
k1 += 1
k2 = 1
while (
k2 <= chars_allowed
and titles[0][(len(titles[0]) - k2) :] == titles[1][(len(titles[1]) - k2) :]
):
k2 += 1

if k1 > k2:
truncate_from_front = True
else:
truncate_from_front = False

for i in range(len(titles)):
title_width = len(titles[i]) * CHARACTER_SIZE_INCHES
if title_width >= cell_width:
titles[i] = (
("..." + titles[i][len(titles[i]) - chars_allowed :])
if truncate_from_front
else (titles[i][:chars_allowed] + "...")
)
return titles


def construct_titles(title_info: Dict[str, List[str]], cell_width: int) -> List[str]:
keys = list(title_info.keys())
nimages = len(title_info[keys[0]])

# truncate longer lines
if nimages > 1:
for key in keys:
title_info[key] = truncate_titles(cell_width, title_info[key])

# join all keys
titles = []
for i in range(nimages):
titles.append("\n".join(title_info[key][i] for key in keys))
return titles


def plot_image_grid(
images: List[Image.Image],
titles: List[str],
title_info: Dict[str, List[str]],
ncols: int,
cell_size: Tuple[int, int],
additional_info: Optional[List[str]] = None,
) -> None:
nrows = math.ceil(len(images) / ncols)
ncols = min(ncols, len(images))
fig, axes = plt.subplots(
nrows, ncols, figsize=(cell_size[0] * ncols, cell_size[1] * nrows)
)

"""Converts font size of 7 into inches"""
CHARACTER_SIZE_INCHES = 7 * (1 / 72)

chars_allowed = math.ceil(cell_size[0] / CHARACTER_SIZE_INCHES) - 4

if len(images) > 1:
k1 = 1
while k1 <= chars_allowed and titles[0][:k1] == titles[1][:k1]:
k1 += 1
k2 = 1
while (
k2 <= chars_allowed
and titles[0][(len(titles[0]) - k2) :] == titles[1][(len(titles[1]) - k2) :]
):
k2 += 1

if k1 > k2:
truncate_from_front = True
else:
truncate_from_front = False

for i in range(len(images)):
title_width = len(titles[i]) * CHARACTER_SIZE_INCHES
if title_width >= cell_size[0]:
titles[i] = (
("..." + titles[i][len(titles[i]) - chars_allowed :])
if truncate_from_front
else (titles[i][:chars_allowed] + "...")
)
if additional_info is not None:
for i in range(len(images)):
titles[i] = f"{titles[i]}\n{additional_info[i]}"
titles = construct_titles(title_info, cell_size[0])
if nrows > 1:
idx = 0
for i in range(nrows):
Expand Down
20 changes: 10 additions & 10 deletions tests/test_viz_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,26 +7,26 @@
class TestVizManager:
@pytest.mark.usefixtures("set_plt_show")
@pytest.mark.parametrize(
("images", "titles"),
("images", "title_info"),
[
([Image.new("L", (100, 100))], ["image_title"]),
([Image.new("L", (100, 100))] * 2, ["image_title"] * 4),
([Image.new("L", (100, 100))] * 6, ["imaxge_title"] * 6),
([Image.new("L", (100, 100))], {"name": ["image_title"]}),
([Image.new("L", (100, 100))] * 2, {"name": ["image_title"] * 4}),
([Image.new("L", (100, 100))] * 6, {"name": ["imaxge_title"] * 6}),
],
ids=["plot single image", "plot <=4 images", "plt > 4 images"],
)
def test_individual_images(self, images, titles):
VizManager.individual_images(images, titles, 4, (2, 2))
def test_individual_images(self, images, title_info):
VizManager.individual_images(images, title_info, 4, (2, 2))

@pytest.mark.usefixtures("set_plt_show")
@pytest.mark.parametrize(
("image_sets", "title_sets"),
("image_sets", "title_info_sets"),
[
(
[[Image.new("L", (100, 100))], [Image.new("L", (100, 100))] * 2],
[["image_title"], ["image_title"] * 2],
[{"name": ["image_title"]}, {"name": ["image_title"] * 2}],
),
],
)
def test_image_sets(self, image_sets, title_sets):
VizManager.image_sets(image_sets, title_sets, 4, (2, 2))
def test_image_sets(self, image_sets, title_info_sets):
VizManager.image_sets(image_sets, title_info_sets, 4, (2, 2))