Skip to content

Commit

Permalink
Merge pull request #714 from roboflow/multi-label-classification-oper…
Browse files Browse the repository at this point in the history
…ations

Multi-Label Classification UQL Operations
  • Loading branch information
PawelPeczek-Roboflow authored Oct 3, 2024
2 parents bba3742 + 63c345c commit cc09811
Show file tree
Hide file tree
Showing 5 changed files with 207 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -705,6 +705,36 @@ class In(BinaryOperator):
type: Literal["in (Sequence)"]


class AllInSequence(BinaryOperator):
model_config = ConfigDict(
json_schema_extra={
"description": "Checks if all elements of first value are elements of second value (usually list)",
"operands_number": 2,
"operands_kinds": [
[LIST_OF_VALUES_KIND],
[LIST_OF_VALUES_KIND],
],
"output_kind": [BOOLEAN_KIND],
},
)
type: Literal["all in (Sequence)"]


class AnyInSequence(BinaryOperator):
model_config = ConfigDict(
json_schema_extra={
"description": "Checks if any element of first value is element of second value (usually list)",
"operands_number": 2,
"operands_kinds": [
[LIST_OF_VALUES_KIND],
[LIST_OF_VALUES_KIND],
],
"output_kind": [BOOLEAN_KIND],
},
)
type: Literal["any in (Sequence)"]


class UnaryOperator(BaseModel):
type: str

Expand Down Expand Up @@ -839,6 +869,8 @@ class BinaryStatement(BaseModel):
comparator: Annotated[
Union[
In,
AllInSequence,
AnyInSequence,
StringContains,
StringEndsWith,
StringStartsWith,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
"(String) endsWith": lambda a, b: a.endswith(b),
"(String) contains": lambda a, b: b in a,
"in (Sequence)": lambda a, b: a in b,
"any in (Sequence)": lambda a, b: any(item in b for item in a),
"all in (Sequence)": lambda a, b: all(item in b for item in a),
"(Detection) in zone": is_point_in_zone,
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,7 @@ def prepare_multi_label_classification_prompt(
{
"role": "system",
"content": "You act as multi-label classification model. You must provide reasonable predictions. "
"You are only allowed to produce JSON document in Markdown ```json [...]``` markers. "
"You are only allowed to produce JSON document in Markdown ```json``` markers. "
'Expected structure of json: {"predicted_classes": [{"class": "class-name-1", "confidence": 0.9}, '
'{"class": "class-name-2", "confidence": 0.7}]}. '
"`class-name-X` must be one of the class names defined by user and `confidence` is a float value in range "
Expand Down
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
import pytest
from inference.core.workflows.core_steps.common.query_language.evaluation_engine.core import (
build_eval_function,
)
from inference.core.workflows.core_steps.common.query_language.entities.operations import (
StatementGroup,
)

CONTINUE_IF_MULTI_LABEL_EXACT_MATCH_PASS = {
"type": "roboflow_core/continue_if@v1",
"name": "exact_match",
"condition_statement": {
"type": "StatementGroup",
"statements": [
{
"type": "BinaryStatement",
"left_operand": {"type": "StaticOperand", "value": ["lion", "zebra"]},
"comparator": {"type": "=="},
"right_operand": {"type": "StaticOperand", "value": ["lion", "zebra"]},
}
],
},
"next_steps": ["$steps.flip"],
"evaluation_parameters": {},
}

CONTINUE_IF_MULTI_LABEL_EXACT_MATCH_FAIL = {
"type": "roboflow_core/continue_if@v1",
"name": "exact_match",
"condition_statement": {
"type": "StatementGroup",
"statements": [
{
"type": "BinaryStatement",
"left_operand": {"type": "StaticOperand", "value": ["lion", "zebra"]},
"comparator": {"type": "=="},
"right_operand": {
"type": "StaticOperand",
"value": ["lion", "zebra", "elephant"],
},
}
],
},
"next_steps": ["$steps.flip"],
"evaluation_parameters": {},
}


CONTINUE_IF_MULTI_LABEL_ANY_MATCH_PASS = {
"type": "roboflow_core/continue_if@v1",
"name": "any_in",
"condition_statement": {
"type": "StatementGroup",
"statements": [
{
"type": "BinaryStatement",
"left_operand": {"type": "StaticOperand", "value": ["lion", "zebra"]},
"comparator": {"type": "any in (Sequence)"},
"right_operand": {
"type": "StaticOperand",
"value": ["cat", "zebra", "dog"],
},
}
],
},
"next_steps": ["$steps.flip"],
"evaluation_parameters": {},
}

CONTINUE_IF_MULTI_LABEL_ANY_MATCH_FAIL = {
"type": "roboflow_core/continue_if@v1",
"name": "any_in",
"condition_statement": {
"type": "StatementGroup",
"statements": [
{
"type": "BinaryStatement",
"left_operand": {"type": "StaticOperand", "value": ["lion", "zebra"]},
"comparator": {"type": "any in (Sequence)"},
"right_operand": {
"type": "StaticOperand",
"value": ["cat", "elephant", "dog"],
},
}
],
},
"next_steps": ["$steps.flip"],
"evaluation_parameters": {},
}


CONTINUE_IF_MULTI_LABEL_ALL_MATCH_PASS = {
"type": "roboflow_core/continue_if@v1",
"name": "all_in",
"condition_statement": {
"type": "StatementGroup",
"statements": [
{
"type": "BinaryStatement",
"left_operand": {"type": "StaticOperand", "value": ["lion", "zebra"]},
"comparator": {"type": "all in (Sequence)"},
"right_operand": {"type": "StaticOperand", "value": ["zebra", "lion"]},
}
],
},
"next_steps": ["$steps.flip"],
"evaluation_parameters": {"left": "$steps.multi_label_classes.predictions"},
}

CONTINUE_IF_MULTI_LABEL_ALL_MATCH_FAIL = {
"type": "roboflow_core/continue_if@v1",
"name": "all_in",
"condition_statement": {
"type": "StatementGroup",
"statements": [
{
"type": "BinaryStatement",
"left_operand": {"type": "StaticOperand", "value": ["lion", "dog"]},
"comparator": {"type": "all in (Sequence)"},
"right_operand": {"type": "StaticOperand", "value": ["zebra", "lion"]},
}
],
},
"next_steps": ["$steps.flip"],
"evaluation_parameters": {"left": "$steps.multi_label_classes.predictions"},
}


@pytest.mark.parametrize(
"condition_statement, evaluation_parameters, expected_result",
[
(
CONTINUE_IF_MULTI_LABEL_EXACT_MATCH_PASS["condition_statement"],
CONTINUE_IF_MULTI_LABEL_EXACT_MATCH_PASS["evaluation_parameters"],
True,
),
(
CONTINUE_IF_MULTI_LABEL_EXACT_MATCH_FAIL["condition_statement"],
CONTINUE_IF_MULTI_LABEL_EXACT_MATCH_FAIL["evaluation_parameters"],
False,
),
(
CONTINUE_IF_MULTI_LABEL_ANY_MATCH_PASS["condition_statement"],
CONTINUE_IF_MULTI_LABEL_ANY_MATCH_PASS["evaluation_parameters"],
True,
),
(
CONTINUE_IF_MULTI_LABEL_ANY_MATCH_FAIL["condition_statement"],
CONTINUE_IF_MULTI_LABEL_ANY_MATCH_FAIL["evaluation_parameters"],
False,
),
(
CONTINUE_IF_MULTI_LABEL_ALL_MATCH_PASS["condition_statement"],
CONTINUE_IF_MULTI_LABEL_ALL_MATCH_PASS["evaluation_parameters"],
True,
),
(
CONTINUE_IF_MULTI_LABEL_ALL_MATCH_FAIL["condition_statement"],
CONTINUE_IF_MULTI_LABEL_ALL_MATCH_FAIL["evaluation_parameters"],
False,
),
],
)
def test_continue_if_evaluation(
condition_statement, evaluation_parameters, expected_result
):
parsed_definition = StatementGroup.model_validate(condition_statement)
evaluation_function = build_eval_function(definition=parsed_definition)
evaluation_result = evaluation_function(evaluation_parameters)
assert (
evaluation_result == expected_result
), f"Expected {expected_result} for condition {condition_statement} with parameters {evaluation_parameters}, but got {evaluation_result}"

0 comments on commit cc09811

Please sign in to comment.