Skip to content

Commit

Permalink
Merge pull request #674 from roboflow/fix/class_remap_uql_operation_w…
Browse files Browse the repository at this point in the history
…ith_parameter

Add ability to parametrise class remap UQL operation
  • Loading branch information
PawelPeczek-Roboflow authored Sep 25, 2024
2 parents b779137 + b4ef640 commit aaacb2a
Show file tree
Hide file tree
Showing 4 changed files with 238 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -468,12 +468,13 @@ class DetectionsRename(OperationDefinition):
},
)
type: Literal["DetectionsRename"]
class_map: Dict[str, str] = Field(
description="Dictionary with classes replacement mapping"
class_map: Union[Dict[str, str], str] = Field(
description="Dictionary with classes replacement mapping or name of "
"parameter delivering the mapping"
)
strict: bool = Field(
description="Flag to decide if all class must be declared in `class_map`. When set `True` "
"all detections classes must be declared, otherwise error is raised.",
strict: Union[bool, str] = Field(
description="Flag to decide if all class must be declared in `class_map` or name of parameter delivering "
"the mapping. When set `True` all detections classes must be declared, otherwise error is raised.",
default=True,
)
new_classes_id_offset: int = Field(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from copy import copy, deepcopy
from typing import Any, Callable, Dict, List
from typing import Any, Callable, Dict, List, Union

import numpy as np
import supervision as sv
Expand All @@ -16,6 +16,7 @@
from inference.core.workflows.core_steps.common.query_language.errors import (
InvalidInputTypeError,
OperationError,
UndeclaredSymbolError,
)
from inference.core.workflows.core_steps.common.query_language.operations.utils import (
safe_stringify,
Expand Down Expand Up @@ -203,9 +204,10 @@ def sort_detections(

def rename_detections(
detections: Any,
class_map: Dict[str, str],
strict: bool,
class_map: Union[Dict[str, str], str],
strict: Union[bool, str],
new_classes_id_offset: int,
global_parameters: Dict[str, Any],
**kwargs,
) -> sv.Detections:
if not isinstance(detections, sv.Detections):
Expand All @@ -215,6 +217,37 @@ def rename_detections(
f"got {value_as_str} of type {type(detections)}",
context="step_execution | roboflow_query_language_evaluation",
)
if isinstance(class_map, str):
if class_map not in global_parameters:
raise UndeclaredSymbolError(
public_message=f"Attempted to retrieve variable `{class_map}` that was expected to hold "
f"class mapping of rename_detections(...), but that turned out not to be registered.",
context="step_execution | roboflow_query_language_evaluation",
)
class_map = global_parameters[class_map]
if not isinstance(class_map, dict):
value_as_str = safe_stringify(value=class_map)
raise InvalidInputTypeError(
public_message=f"Executing rename_detections(...), expected dictionary to be given as class map, "
f"got {value_as_str} of type {type(class_map)}",
context="step_execution | roboflow_query_language_evaluation",
)
if isinstance(strict, str):
if strict not in global_parameters:
raise UndeclaredSymbolError(
public_message=f"Attempted to retrieve variable `{strict}` that was expected to hold "
f"parameter for `strict` flag of rename_detections(...), but that turned out not "
f"to be registered.",
context="step_execution | roboflow_query_language_evaluation",
)
strict = global_parameters[strict]
if not isinstance(strict, bool):
value_as_str = safe_stringify(value=strict)
raise InvalidInputTypeError(
public_message=f"Executing rename_detections(...), expected dictionary to be given as `strict` flag, "
f"got {value_as_str} of type {type(strict)}",
context="step_execution | roboflow_query_language_evaluation",
)
detections_copy = deepcopy(detections)
original_class_names = detections_copy.data.get("class_name", []).tolist()
original_class_ids = detections_copy.class_id.tolist()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from copy import deepcopy
from typing import Dict

import numpy as np
Expand All @@ -7,7 +8,9 @@
from inference.core.managers.base import ModelManager
from inference.core.workflows.core_steps.common.entities import StepExecutionMode
from inference.core.workflows.core_steps.common.query_language.errors import (
InvalidInputTypeError,
OperationError,
UndeclaredSymbolError,
)
from inference.core.workflows.execution_engine.core import ExecutionEngine
from tests.workflows.integration_tests.execution.workflows_gallery_collector.decorators import (
Expand Down Expand Up @@ -211,3 +214,156 @@ def test_class_rename_workflow_with_strict_mapping_when_not_all_classes_are_rema
"model_id": "yolov8n-640",
},
)


WORKFLOW_WITH_PARAMETRISED_DETECTIONS_RENAME = {
"version": "1.0",
"inputs": [
{"type": "WorkflowImage", "name": "image"},
{"type": "WorkflowParameter", "name": "confidence", "default_value": 0.4},
{"type": "WorkflowParameter", "name": "class_map"},
{"type": "WorkflowParameter", "name": "strict"},
],
"steps": [
{
"type": "ObjectDetectionModel",
"name": "model",
"image": "$inputs.image",
"model_id": "yolov8n-640",
"confidence": "$inputs.confidence",
},
{
"type": "DetectionsTransformation",
"name": "class_rename",
"predictions": "$steps.model.predictions",
"operations": [
{
"type": "DetectionsRename",
"strict": "strict",
"class_map": "class_map",
}
],
"operations_parameters": {
"class_map": "$inputs.class_map",
"strict": "$inputs.strict",
},
},
],
"outputs": [
{
"type": "JsonField",
"name": "original_predictions",
"selector": "$steps.model.predictions",
},
{
"type": "JsonField",
"name": "renamed_predictions",
"selector": "$steps.class_rename.predictions",
},
],
}


def test_class_rename_workflow_when_mapping_is_parametrised(
model_manager: ModelManager,
fruit_image: np.ndarray,
) -> None:
workflow_init_parameters = {
"workflows_core.model_manager": model_manager,
"workflows_core.api_key": None,
"workflows_core.step_execution_mode": StepExecutionMode.LOCAL,
}
execution_engine = ExecutionEngine.init(
workflow_definition=WORKFLOW_WITH_PARAMETRISED_DETECTIONS_RENAME,
init_parameters=workflow_init_parameters,
max_concurrent_steps=WORKFLOWS_MAX_CONCURRENT_STEPS,
)

# when
result = execution_engine.run(
runtime_parameters={
"image": fruit_image,
"model_id": "yolov8n-640",
"class_map": {"apple": "fruit", "banana": "fruit"},
"strict": False,
},
)

# then
assert isinstance(result, list), "Expected result to be list"
assert len(result) == 1, "Single image provided - single output expected"

assert result[0]["renamed_predictions"]["class_name"].tolist() == [
"fruit",
"fruit",
"fruit",
"orange",
"fruit",
], "Expected renamed set of classes to be the same as when test was created"
assert result[0]["renamed_predictions"].class_id.tolist() == [
1024,
1024,
1024,
49,
1024,
], "Expected renamed set of class ids to be the same as when test was created"
assert len(result[0]["renamed_predictions"]) == len(
result[0]["original_predictions"]
), "Expected length of predictions no to change"


def test_class_rename_workflow_when_mapping_is_parametrised_with_invalid_value(
model_manager: ModelManager,
fruit_image: np.ndarray,
) -> None:
workflow_init_parameters = {
"workflows_core.model_manager": model_manager,
"workflows_core.api_key": None,
"workflows_core.step_execution_mode": StepExecutionMode.LOCAL,
}
execution_engine = ExecutionEngine.init(
workflow_definition=WORKFLOW_WITH_PARAMETRISED_DETECTIONS_RENAME,
init_parameters=workflow_init_parameters,
max_concurrent_steps=WORKFLOWS_MAX_CONCURRENT_STEPS,
)

# when
with pytest.raises(InvalidInputTypeError):
_ = execution_engine.run(
runtime_parameters={
"image": fruit_image,
"model_id": "yolov8n-640",
"class_map": "INVALID",
"strict": False,
},
)


def test_class_rename_workflow_when_mapping_is_not_passed_as_operation_parameter_leaving_undeclared_symbol(
model_manager: ModelManager,
fruit_image: np.ndarray,
) -> None:
# given
workflow_init_parameters = {
"workflows_core.model_manager": model_manager,
"workflows_core.api_key": None,
"workflows_core.step_execution_mode": StepExecutionMode.LOCAL,
}
workflow_definition = deepcopy(WORKFLOW_WITH_PARAMETRISED_DETECTIONS_RENAME)
del workflow_definition["steps"][1]["operations_parameters"]
execution_engine = ExecutionEngine.init(
workflow_definition=workflow_definition,
init_parameters=workflow_init_parameters,
max_concurrent_steps=WORKFLOWS_MAX_CONCURRENT_STEPS,
)

# when
with pytest.raises(UndeclaredSymbolError):
_ = execution_engine.run(
runtime_parameters={
"image": fruit_image,
"model_id": "yolov8n-640",
"class_map": "INVALID",
"strict": False,
},
)
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def test_rename_detections_when_not_sv_detections_provided() -> None:
class_map={"a": "b"},
strict=True,
new_classes_id_offset=0,
global_parameters={},
)


Expand All @@ -37,6 +38,7 @@ def test_rename_detections_when_strict_mode_enabled_and_all_classes_present() ->
class_map={"a": "A", "b": "B"},
strict=True,
new_classes_id_offset=1024,
global_parameters={},
)

# then
Expand Down Expand Up @@ -71,6 +73,7 @@ def test_rename_detections_when_strict_mode_enabled_and_not_all_classes_present(
class_map={"a": "A"},
strict=True,
new_classes_id_offset=1024,
global_parameters={},
)


Expand All @@ -91,6 +94,7 @@ def test_rename_detections_when_non_strict_mode_enabled_and_all_classes_present(
class_map={"a": "A", "b": "B"},
strict=False,
new_classes_id_offset=1024,
global_parameters={},
)

# then
Expand Down Expand Up @@ -124,6 +128,7 @@ def test_rename_detections_when_non_strict_mode_enabled_and_not_all_classes_pres
class_map={"a": "A"},
strict=False,
new_classes_id_offset=1024,
global_parameters={},
)

# then
Expand All @@ -138,3 +143,38 @@ def test_rename_detections_when_non_strict_mode_enabled_and_not_all_classes_pres
"A",
"b",
], "Expected to change with mapping"


def test_rename_detections_when_mapping_is_parametrised() -> None:
# given
detections = sv.Detections(
xyxy=np.array([[0, 1, 2, 3], [0, 1, 2, 3]]),
class_id=np.array([10, 11]),
confidence=np.array([0.3, 0.4]),
data={"class_name": np.array(["a", "b"])},
)

# when
result = rename_detections(
detections=detections,
class_map="class_map_param",
strict="strict_param",
new_classes_id_offset=1024,
global_parameters={
"strict_param": True,
"class_map_param": {"a": "A", "b": "B"},
},
)

# then
assert np.allclose(
result.xyxy, np.array([[0, 1, 2, 3], [0, 1, 2, 3]])
), "Expected no to change"
assert np.allclose(result.confidence, np.array([0.3, 0.4])), "Expected no to change"
assert np.allclose(
result.class_id, np.array([0, 1])
), "Expected to change with mapping"
assert result.data["class_name"].tolist() == [
"A",
"B",
], "Expected to change with mapping"

0 comments on commit aaacb2a

Please sign in to comment.