Merge pull request #672 from roboflow/feature/masked_crop
Masked crop block
PawelPeczek-Roboflow authored Sep 26, 2024
2 parents 52b16cc + f53fa69 commit 981dc55
Showing 3 changed files with 450 additions and 17 deletions.
126 changes: 114 additions & 12 deletions inference/core/workflows/core_steps/transformations/dynamic_crop/v1.py
@@ -1,6 +1,8 @@
from dataclasses import replace
from typing import Dict, List, Literal, Optional, Type, Union
from typing import Dict, List, Literal, Optional, Tuple, Type, Union

import cv2
import numpy as np
import supervision as sv
from pydantic import AliasChoices, ConfigDict, Field

@@ -13,13 +15,17 @@
WorkflowImageData,
)
from inference.core.workflows.execution_engine.entities.types import (
FLOAT_ZERO_TO_ONE_KIND,
IMAGE_KIND,
INSTANCE_SEGMENTATION_PREDICTION_KIND,
KEYPOINT_DETECTION_PREDICTION_KIND,
OBJECT_DETECTION_PREDICTION_KIND,
RGB_COLOR_KIND,
STRING_KIND,
StepOutputImageSelector,
StepOutputSelector,
WorkflowImageSelector,
WorkflowParameterSelector,
)
from inference.core.workflows.prototypes.block import (
BlockResult,
Expand All @@ -34,6 +40,11 @@
workflow. For example, you could use an ObjectDetection block to detect objects, then
the DynamicCropBlock block to crop objects, then an OCR block to run character recognition on
each of the individual cropped regions.
In addition, for instance segmentation predictions (which provide a segmentation mask for
each bounding box) it is possible to remove the background in the crops, outside of the
detected instances. To enable that functionality, set `mask_opacity` to a positive value
and optionally tune `background_color`.
"""


@@ -67,6 +78,36 @@ class BlockManifest(WorkflowBlockManifest):
examples=["$steps.my_object_detection_model.predictions"],
validation_alias=AliasChoices("predictions", "detections"),
)
mask_opacity: Union[
WorkflowParameterSelector(kind=[FLOAT_ZERO_TO_ONE_KIND]),
float,
] = Field(
default=0.0,
le=1.0,
ge=0.0,
description="For instance segmentation, mask_opacity can be used to control background removal. "
"Opacity 1.0 removes the background, while 0.0 leaves the crop unchanged.",
json_schema_extra={
"relevant_for": {
"predictions": {
"kind": [INSTANCE_SEGMENTATION_PREDICTION_KIND.name],
"required": True,
},
}
},
)
background_color: Union[
WorkflowParameterSelector(kind=[STRING_KIND]),
StepOutputSelector(kind=[RGB_COLOR_KIND]),
str,
Tuple[int, int, int],
] = Field(
default=(0, 0, 0),
description="For background removal based on the segmentation mask, a new background color "
"can be selected. Can be a hex string (like '#431112'), an RGB string (like '(128, 32, 64)'), "
"or an RGB tuple (like (18, 17, 67)).",
examples=["#431112", "$inputs.bg_color", (18, 17, 67)],
)

@classmethod
def accepts_batch_input(cls) -> bool:
@@ -97,16 +138,25 @@ def run(
self,
images: Batch[WorkflowImageData],
predictions: Batch[sv.Detections],
mask_opacity: float,
background_color: Union[str, Tuple[int, int, int]],
) -> BlockResult:
return [
crop_image(image=image, detections=detections)
crop_image(
image=image,
detections=detections,
mask_opacity=mask_opacity,
background_color=background_color,
)
for image, detections in zip(images, predictions)
]


def crop_image(
image: WorkflowImageData,
detections: sv.Detections,
mask_opacity: float,
background_color: Union[str, Tuple[int, int, int]],
detection_id_key: str = DETECTION_ID_KEY,
) -> List[Dict[str, WorkflowImageData]]:
if len(detections) == 0:
@@ -117,10 +167,24 @@ def crop_image(
f"in data dictionary."
)
crops = []
for (x_min, y_min, x_max, y_max), detection_id in zip(
detections.xyxy.round().astype(dtype=int), detections[detection_id_key]
for idx, ((x_min, y_min, x_max, y_max), detection_id) in enumerate(
zip(detections.xyxy.round().astype(dtype=int), detections[detection_id_key])
):
cropped_image = image.numpy_image[y_min:y_max, x_min:x_max]
if not cropped_image.size:
crops.append({"crops": None})
continue
if mask_opacity > 0 and detections.mask is not None:
detection_mask = detections.mask[idx]
cropped_mask = np.stack(
[detection_mask[y_min:y_max, x_min:x_max]] * 3, axis=-1
)
cropped_image = overlay_crop_with_mask(
crop=cropped_image,
mask=cropped_mask,
mask_opacity=mask_opacity,
background_color=background_color,
)
parent_metadata = ImageParentMetadata(
parent_id=detection_id,
origin_coordinates=OriginCoordinatesSystem(
@@ -141,13 +205,51 @@
parent_id=image.workflow_root_ancestor_metadata.parent_id,
origin_coordinates=workflow_root_ancestor_coordinates,
)
if cropped_image.size:
result = WorkflowImageData(
parent_metadata=parent_metadata,
workflow_root_ancestor_metadata=workflow_root_ancestor_metadata,
numpy_image=cropped_image,
)
else:
result = None
result = WorkflowImageData(
parent_metadata=parent_metadata,
workflow_root_ancestor_metadata=workflow_root_ancestor_metadata,
numpy_image=cropped_image,
)
crops.append({"crops": result})
return crops


def overlay_crop_with_mask(
crop: np.ndarray,
mask: np.ndarray,
mask_opacity: float,
background_color: Union[str, Tuple[int, int, int]],
) -> np.ndarray:
bgr_color = convert_color_to_bgr_tuple(color=background_color)
background = (np.ones_like(crop) * bgr_color).astype(np.uint8)
blended_crop = np.where(mask > 0, crop, background)
return cv2.addWeighted(blended_crop, mask_opacity, crop, 1.0 - mask_opacity, 0)


def convert_color_to_bgr_tuple(
color: Union[str, Tuple[int, int, int]]
) -> Tuple[int, int, int]:
if isinstance(color, str):
return convert_string_color_to_bgr_tuple(color=color)
if isinstance(color, tuple) and len(color) == 3:
return color[::-1]
raise ValueError(f"Invalid color format: {color}")


def convert_string_color_to_bgr_tuple(color: str) -> Tuple[int, int, int]:
if color.startswith("#") and len(color) == 7:
try:
return tuple(int(color[i : i + 2], 16) for i in (5, 3, 1))
except ValueError as e:
raise ValueError(f"Invalid hex color format: {color}") from e
if color.startswith("#") and len(color) == 4:
try:
return tuple(int(color[i] + color[i], 16) for i in (3, 2, 1))
except ValueError as e:
raise ValueError(f"Invalid hex color format: {color}") from e
if color.startswith("(") and color.endswith(")"):
try:
return tuple(map(int, color[1:-1].split(",")))[::-1]
except ValueError as e:
raise ValueError(f"Invalid tuple color format: {color}") from e
raise ValueError(f"Invalid color format: {color}")
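
A minimal sketch of how the three accepted background_color formats map to BGR tuples, using the helper added in this file (module path taken from the file header above):

from inference.core.workflows.core_steps.transformations.dynamic_crop.v1 import (
    convert_color_to_bgr_tuple,
)

# Hex string: "#431112" parses as RGB (0x43, 0x11, 0x12) and is returned in BGR order.
assert convert_color_to_bgr_tuple(color="#431112") == (0x12, 0x11, 0x43)
# RGB string: "(128, 32, 64)" is split into ints and reversed to BGR.
assert convert_color_to_bgr_tuple(color="(128, 32, 64)") == (64, 32, 128)
# RGB tuple: (18, 17, 67) is reversed to BGR.
assert convert_color_to_bgr_tuple(color=(18, 17, 67)) == (67, 17, 18)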
@@ -0,0 +1,140 @@
import numpy as np

from inference.core.env import WORKFLOWS_MAX_CONCURRENT_STEPS
from inference.core.managers.base import ModelManager
from inference.core.workflows.core_steps.common.entities import StepExecutionMode
from inference.core.workflows.execution_engine.core import ExecutionEngine
from tests.workflows.integration_tests.execution.workflows_gallery_collector.decorators import (
add_to_workflows_gallery,
)

MASKED_CROP_WORKFLOW = {
"version": "1.0",
"inputs": [
{"type": "WorkflowImage", "name": "image"},
{
"type": "WorkflowParameter",
"name": "model_id",
"default_value": "yolov8n-seg-640",
},
{
"type": "WorkflowParameter",
"name": "confidence",
"default_value": 0.4,
},
],
"steps": [
{
"type": "roboflow_core/roboflow_instance_segmentation_model@v1",
"name": "segmentation",
"image": "$inputs.image",
"model_id": "$inputs.model_id",
"confidence": "$inputs.confidence",
},
{
"type": "roboflow_core/dynamic_crop@v1",
"name": "cropping",
"image": "$inputs.image",
"predictions": "$steps.segmentation.predictions",
"mask_opacity": 1.0,
},
],
"outputs": [
{
"type": "JsonField",
"name": "crops",
"selector": "$steps.cropping.crops",
},
{
"type": "JsonField",
"name": "predictions",
"selector": "$steps.segmentation.predictions",
},
],
}
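
# Note: background_color is left at its default (0, 0, 0) here, so with
# mask_opacity=1.0 everything outside each instance mask comes out black,
# which is what test_workflow_with_masked_crop asserts below.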


@add_to_workflows_gallery(
category="Workflows with data transformations",
use_case_title="Instance Segmentation results with background subtracted",
use_case_description="""
    This example showcases how to extract all instances detected by an instance segmentation
    model as separate crops with the background removed.
""",
workflow_definition=MASKED_CROP_WORKFLOW,
workflow_name_in_app="segmentation-plus-masked-crop",
)
def test_workflow_with_masked_crop(
model_manager: ModelManager,
dogs_image: np.ndarray,
roboflow_api_key: str,
) -> None:
# given
workflow_init_parameters = {
"workflows_core.model_manager": model_manager,
"workflows_core.api_key": roboflow_api_key,
"workflows_core.step_execution_mode": StepExecutionMode.LOCAL,
}
execution_engine = ExecutionEngine.init(
workflow_definition=MASKED_CROP_WORKFLOW,
init_parameters=workflow_init_parameters,
max_concurrent_steps=WORKFLOWS_MAX_CONCURRENT_STEPS,
)

# when
result = execution_engine.run(
runtime_parameters={
"image": dogs_image,
}
)

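# then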
assert isinstance(result, list), "Expected list to be delivered"
assert len(result) == 1, "Expected 1 element in the output for one input image"
assert set(result[0].keys()) == {
"crops",
"predictions",
}, "Expected all declared outputs to be delivered"
assert len(result[0]["crops"]) == 2, "Expected 2 crops for two dogs detected"
crop_image = result[0]["crops"][0].numpy_image
(x_min, y_min, x_max, y_max) = (
result[0]["predictions"].xyxy[0].round().astype(dtype=int)
)
crop_mask = result[0]["predictions"].mask[0][y_min:y_max, x_min:x_max]
pixels_outside_mask = np.where(
np.stack([crop_mask] * 3, axis=-1) == 0,
crop_image,
np.zeros_like(crop_image),
)
pixels_sum = pixels_outside_mask.sum()
assert pixels_sum == 0, "Expected everything black outside mask"


def test_workflow_with_masked_crop_when_nothing_gets_predicted(
model_manager: ModelManager,
dogs_image: np.ndarray,
roboflow_api_key: str,
) -> None:
# given
workflow_init_parameters = {
"workflows_core.model_manager": model_manager,
"workflows_core.api_key": roboflow_api_key,
"workflows_core.step_execution_mode": StepExecutionMode.LOCAL,
}
execution_engine = ExecutionEngine.init(
workflow_definition=MASKED_CROP_WORKFLOW,
init_parameters=workflow_init_parameters,
max_concurrent_steps=WORKFLOWS_MAX_CONCURRENT_STEPS,
)

# when
result = execution_engine.run(
runtime_parameters={"image": dogs_image, "confidence": 0.99}
)

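# then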
assert isinstance(result, list), "Expected list to be delivered"
assert len(result) == 1, "Expected 1 element in the output for one input image"
assert set(result[0].keys()) == {
"crops",
"predictions",
}, "Expected all declared outputs to be delivered"
assert len(result[0]["crops"]) == 0, "Expected no crops when nothing is detected"
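
A minimal standalone sketch of the blend semantics (same assumed module path as above): with mask_opacity=1.0, pixels outside the mask should equal the requested background color exactly, while pixels inside the mask stay untouched:

import numpy as np

from inference.core.workflows.core_steps.transformations.dynamic_crop.v1 import (
    overlay_crop_with_mask,
)

crop = np.full((4, 4, 3), 200, dtype=np.uint8)  # uniform gray crop
mask = np.zeros((4, 4, 3), dtype=np.uint8)
mask[1:3, 1:3, :] = 1  # instance covers the central 2x2 patch
out = overlay_crop_with_mask(
    crop=crop, mask=mask, mask_opacity=1.0, background_color="#ff0000"
)
assert tuple(out[0, 0]) == (0, 0, 255)  # outside the mask: pure red, in BGR order
assert tuple(out[1, 1]) == (200, 200, 200)  # inside the mask: original pixel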