1 change: 1 addition & 0 deletions keras/__init__.py
@@ -50,6 +50,7 @@
from keras.api import tree
from keras.api import utils
from keras.api import version
from keras.api import visualization

# END DO NOT EDIT.

1 change: 1 addition & 0 deletions keras/api/__init__.py
@@ -31,6 +31,7 @@
from keras.api import saving
from keras.api import tree
from keras.api import utils
from keras.api import visualization
from keras.src.backend import Variable
from keras.src.backend import device
from keras.src.backend import name_scope
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/__init__.py
@@ -24,6 +24,7 @@
from keras.api import regularizers
from keras.api import tree
from keras.api import utils
from keras.api import visualization
from keras.api._tf_keras.keras import backend
from keras.api._tf_keras.keras import layers
from keras.api._tf_keras.keras import losses
17 changes: 17 additions & 0 deletions keras/api/_tf_keras/keras/visualization/__init__.py
@@ -0,0 +1,17 @@
"""DO NOT EDIT.

This file was autogenerated. Do not edit it by hand,
since your modifications would be overwritten.
"""

from keras.src.visualization.draw_bounding_boxes import draw_bounding_boxes
from keras.src.visualization.draw_segmentation_masks import (
draw_segmentation_masks,
)
from keras.src.visualization.plot_bounding_box_gallery import (
plot_bounding_box_gallery,
)
from keras.src.visualization.plot_image_gallery import plot_image_gallery
from keras.src.visualization.plot_segmentation_mask_gallery import (
plot_segmentation_mask_gallery,
)
17 changes: 17 additions & 0 deletions keras/api/visualization/__init__.py
@@ -0,0 +1,17 @@
"""DO NOT EDIT.

This file was autogenerated. Do not edit it by hand,
since your modifications would be overwritten.
"""

from keras.src.visualization.draw_bounding_boxes import draw_bounding_boxes
from keras.src.visualization.draw_segmentation_masks import (
draw_segmentation_masks,
)
from keras.src.visualization.plot_bounding_box_gallery import (
plot_bounding_box_gallery,
)
from keras.src.visualization.plot_image_gallery import plot_image_gallery
from keras.src.visualization.plot_segmentation_mask_gallery import (
plot_segmentation_mask_gallery,
)
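
For reviewers, a minimal smoke-check sketch (not part of the PR) of the public surface these autogenerated API files expose; it assumes a Keras build that includes this change:

import keras

# After this PR, the five visualization helpers resolve from the public
# namespace instead of from keras.src internals.
for name in (
    "draw_bounding_boxes",
    "draw_segmentation_masks",
    "plot_bounding_box_gallery",
    "plot_image_gallery",
    "plot_segmentation_mask_gallery",
):
    assert callable(getattr(keras.visualization, name))
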
1 change: 1 addition & 0 deletions keras/src/__init__.py
@@ -10,6 +10,7 @@
from keras.src import optimizers
from keras.src import regularizers
from keras.src import utils
from keras.src import visualization
from keras.src.backend import KerasTensor
from keras.src.layers import Input
from keras.src.layers import Layer
2 changes: 2 additions & 0 deletions keras/src/visualization/__init__.py
@@ -0,0 +1,2 @@
from keras.src.visualization import draw_bounding_boxes
from keras.src.visualization import plot_image_gallery
177 changes: 177 additions & 0 deletions keras/src/visualization/draw_bounding_boxes.py
@@ -0,0 +1,177 @@
import numpy as np

from keras.src import backend
from keras.src import ops
from keras.src.api_export import keras_export
from keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import ( # noqa: E501
convert_format,
)

try:
import cv2
except ImportError:
cv2 = None


@keras_export("keras.visualization.draw_bounding_boxes")
def draw_bounding_boxes(
images,
bounding_boxes,
bounding_box_format,
class_mapping=None,
color=(128, 128, 128),
line_thickness=2,
text_thickness=1,
font_scale=1.0,
data_format=None,
):
"""Draws bounding boxes on images.

This function draws bounding boxes on a batch of images. It supports
different bounding box formats and can optionally display class labels
and confidences.

Args:
images: A batch of images as a 4D tensor or NumPy array. Shape should be
`(batch_size, height, width, channels)`.
bounding_boxes: A dictionary containing bounding box data. Should have
the following keys:
- `boxes`: A tensor or array of shape `(batch_size, num_boxes, 4)`
containing the bounding box coordinates in the specified format.
- `labels`: A tensor or array of shape `(batch_size, num_boxes)`
containing the class labels for each bounding box.
- `confidences` (Optional): A tensor or array of shape
`(batch_size, num_boxes)` containing the confidence scores for
each bounding box.
bounding_box_format: A string specifying the format of the bounding
            boxes. Refer to [keras-io](TODO) for the supported formats.
class_mapping: A dictionary mapping class IDs (integers) to class labels
(strings). Used to display class labels next to the bounding boxes.
Defaults to None (no labels displayed).
color: A tuple or list representing the RGB color of the bounding boxes.
For example, `(255, 0, 0)` for red. Defaults to `(128, 128, 128)`.
line_thickness: An integer specifying the thickness of the bounding box
lines. Defaults to `2`.
text_thickness: An integer specifying the thickness of the text labels.
Defaults to `1`.
font_scale: A float specifying the scale of the font used for text
labels. Defaults to `1.0`.
data_format: A string, either `"channels_last"` or `"channels_first"`,
specifying the order of dimensions in the input images. Defaults to
the `image_data_format` value found in your Keras config file at
`~/.keras/keras.json`. If you never set it, then it will be
"channels_last".

Returns:
A NumPy array of the annotated images with the bounding boxes drawn.
The array will have the same shape as the input `images`.

Raises:
ValueError: If `images` is not a 4D tensor/array, if `bounding_boxes` is
not a dictionary, or if `bounding_boxes` does not contain `"boxes"`
and `"labels"` keys.
TypeError: If `bounding_boxes` is not a dictionary.
ImportError: If `cv2` (OpenCV) is not installed.
"""

if cv2 is None:
raise ImportError(
"The `draw_bounding_boxes` function requires the `cv2` package "
" (OpenCV). Please install it with `pip install opencv-python`."
)

class_mapping = class_mapping or {}
text_thickness = (
text_thickness or line_thickness
) # Default text_thickness if not provided.
data_format = data_format or backend.image_data_format()
images_shape = ops.shape(images)
if len(images_shape) != 4:
raise ValueError(
"`images` must be batched 4D tensor. "
f"Received: images.shape={images_shape}"
)
if not isinstance(bounding_boxes, dict):
raise TypeError(
"`bounding_boxes` should be a dict. "
f"Received: bounding_boxes={bounding_boxes} of type "
f"{type(bounding_boxes)}"
)
if "boxes" not in bounding_boxes or "labels" not in bounding_boxes:
raise ValueError(
"`bounding_boxes` should be a dict containing 'boxes' and "
f"'labels' keys. Received: bounding_boxes={bounding_boxes}"
)
if data_format == "channels_last":
h_axis = -3
w_axis = -2
else:
h_axis = -2
w_axis = -1
height = images_shape[h_axis]
width = images_shape[w_axis]
bounding_boxes = bounding_boxes.copy()
bounding_boxes = convert_format(
bounding_boxes, bounding_box_format, "xyxy", height, width
)

# To numpy array
images = ops.convert_to_numpy(images).astype("uint8")
boxes = ops.convert_to_numpy(bounding_boxes["boxes"])
labels = ops.convert_to_numpy(bounding_boxes["labels"])
if "confidences" in bounding_boxes:
confidences = ops.convert_to_numpy(bounding_boxes["confidences"])
else:
confidences = None

result = []
batch_size = images.shape[0]
for i in range(batch_size):
_image = images[i]
_box = boxes[i]
_class = labels[i]
for box_i in range(_box.shape[0]):
x1, y1, x2, y2 = _box[box_i].astype("int32")
c = _class[box_i].astype("int32")
if c == -1:
continue
x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
c = int(c)
# Draw bounding box
cv2.rectangle(_image, (x1, y1), (x2, y2), color, line_thickness)

if c in class_mapping:
label = class_mapping[c]
if confidences is not None:
conf = confidences[i][box_i]
label = f"{label} | {conf:.2f}"

font_x1, font_y1 = _find_text_location(
x1, y1, font_scale, text_thickness
)
cv2.putText(
img=_image,
text=label,
org=(font_x1, font_y1),
fontFace=cv2.FONT_HERSHEY_SIMPLEX,
fontScale=font_scale,
color=color,
thickness=text_thickness,
)
result.append(_image)
return np.stack(result, axis=0)


def _find_text_location(x, y, font_scale, thickness):
font_height = int(font_scale * 12)
target_y = y - 8
if target_y - (2 * font_height) > 0:
return x, y - 8

line_offset = thickness
static_offset = 3

return (
x + static_offset,
y + (2 * font_height) + line_offset + static_offset,
)
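
A minimal, hypothetical usage sketch for the function above (not part of the diff). It assumes `opencv-python` is installed, uses the public `keras.visualization` path added by this PR, and made-up images, boxes, and a two-class mapping:

import numpy as np
from keras import visualization  # public namespace added by this PR

# Two 64x64 RGB images, one "xyxy" pixel-coordinate box per image.
images = np.random.randint(0, 255, size=(2, 64, 64, 3), dtype="uint8")
bounding_boxes = {
    "boxes": np.array([[[4.0, 4.0, 30.0, 30.0]], [[10.0, 10.0, 50.0, 40.0]]]),
    "labels": np.array([[0], [1]]),
    "confidences": np.array([[0.91], [0.57]]),  # optional
}
annotated = visualization.draw_bounding_boxes(
    images,
    bounding_boxes,
    bounding_box_format="xyxy",
    class_mapping={0: "cat", 1: "dog"},
    color=(255, 0, 0),
)
print(annotated.shape)  # (2, 64, 64, 3), same shape as the input batch
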
109 changes: 109 additions & 0 deletions keras/src/visualization/draw_segmentation_masks.py
@@ -0,0 +1,109 @@
import numpy as np

from keras.src import backend
from keras.src import ops
from keras.src.api_export import keras_export


@keras_export("keras.visualization.draw_segmentation_masks")
def draw_segmentation_masks(
images,
segmentation_masks,
num_classes=None,
color_mapping=None,
alpha=0.8,
blend=True,
ignore_index=-1,
data_format=None,
):
"""Draws segmentation masks on images.

The function overlays segmentation masks on the input images.
The masks are blended with the images using the specified alpha value.

Args:
images: A batch of images as a 4D tensor or NumPy array. Shape
should be (batch_size, height, width, channels).
segmentation_masks: A batch of segmentation masks as a 3D or 4D tensor
or NumPy array. Shape should be (batch_size, height, width) or
(batch_size, height, width, 1). The values represent class indices
starting from 1 up to `num_classes`. Class 0 is reserved for
the background and will be ignored if `ignore_index` is not 0.
num_classes: The number of segmentation classes. If `None`, it is
inferred from the maximum value in `segmentation_masks`.
color_mapping: A dictionary mapping class indices to RGB colors.
If `None`, a default color palette is generated. The keys should be
integers starting from 1 up to `num_classes`.
alpha: The opacity of the segmentation masks. Must be in the range
`[0, 1]`.
blend: Whether to blend the masks with the input image using the
`alpha` value. If `False`, the masks are drawn directly on the
images without blending. Defaults to `True`.
ignore_index: The class index to ignore. Mask pixels with this value
will not be drawn. Defaults to -1.
data_format: Image data format, either `"channels_last"` or
`"channels_first"`. Defaults to the `image_data_format` value found
in your Keras config file at `~/.keras/keras.json`. If you never
set it, then it will be `"channels_last"`.

Returns:
A NumPy array of the images with the segmentation masks overlaid.

Raises:
ValueError: If the input `images` is not a 4D tensor or NumPy array.
TypeError: If the input `segmentation_masks` is not an integer type.
"""
data_format = data_format or backend.image_data_format()
images_shape = ops.shape(images)
if len(images_shape) != 4:
raise ValueError(
"`images` must be batched 4D tensor. "
f"Received: images.shape={images_shape}"
)
if data_format == "channels_first":
images = ops.transpose(images, (0, 2, 3, 1))
segmentation_masks = ops.transpose(segmentation_masks, (0, 2, 3, 1))
images = ops.convert_to_tensor(images, dtype="float32")
segmentation_masks = ops.convert_to_tensor(segmentation_masks)

if not backend.is_int_dtype(segmentation_masks.dtype):
dtype = backend.standardize_dtype(segmentation_masks.dtype)
raise TypeError(
"`segmentation_masks` must be in integer dtype. "
f"Received: segmentation_masks.dtype={dtype}"
)

# Infer num_classes
if num_classes is None:
num_classes = int(ops.convert_to_numpy(ops.max(segmentation_masks)))
if color_mapping is None:
colors = _generate_color_palette(num_classes)
else:
colors = [color_mapping[i] for i in range(num_classes)]
valid_masks = ops.not_equal(segmentation_masks, ignore_index)
valid_masks = ops.squeeze(valid_masks, axis=-1)
segmentation_masks = ops.one_hot(segmentation_masks, num_classes)
segmentation_masks = segmentation_masks[..., 0, :]
segmentation_masks = ops.convert_to_numpy(segmentation_masks)

# Replace class with color
masks = segmentation_masks
masks = np.transpose(masks, axes=(3, 0, 1, 2)).astype("bool")
images_to_draw = ops.convert_to_numpy(images).copy()
for mask, color in zip(masks, colors):
color = np.array(color, dtype=images_to_draw.dtype)
images_to_draw[mask, ...] = color[None, :]
images_to_draw = ops.convert_to_tensor(images_to_draw)
outputs = ops.cast(images_to_draw, dtype="float32")

if blend:
outputs = images * (1 - alpha) + outputs * alpha
outputs = ops.where(valid_masks[..., None], outputs, images)
outputs = ops.cast(outputs, dtype="uint8")
outputs = ops.convert_to_numpy(outputs)
return outputs


def _generate_color_palette(num_classes: int):
palette = np.array([2**25 - 1, 2**15 - 1, 2**21 - 1])
return [((i * palette) % 255).tolist() for i in range(num_classes)]
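
And a similar hypothetical sketch for the mask overlay above (again not part of the diff), using one made-up image and a single-channel integer mask of shape `(batch, height, width, 1)`, which is what the `squeeze`/one-hot indexing in the function expects:

import numpy as np
from keras import visualization  # public namespace added by this PR

# One 32x32 RGB image and an int32 mask of shape (batch, h, w, 1).
images = np.random.randint(0, 255, size=(1, 32, 32, 3), dtype="uint8")
masks = np.zeros((1, 32, 32, 1), dtype="int32")
masks[0, 8:24, 8:24, 0] = 1  # a square patch of class 1

overlaid = visualization.draw_segmentation_masks(
    images, masks, num_classes=2, alpha=0.6
)
print(overlaid.shape, overlaid.dtype)  # (1, 32, 32, 3) uint8
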