Skip to content

Commit

Permalink
Merge branch 'main' into fix/yolo-settings
Browse files Browse the repository at this point in the history
  • Loading branch information
PawelPeczek-Roboflow authored Nov 14, 2024
2 parents aa0a209 + 8193970 commit cd1bb27
Show file tree
Hide file tree
Showing 4 changed files with 133 additions and 29 deletions.
4 changes: 4 additions & 0 deletions inference/core/workflows/core_steps/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,9 @@
from inference.core.workflows.core_steps.models.foundation.florence2.v1 import (
Florence2BlockV1,
)
from inference.core.workflows.core_steps.models.foundation.florence2.v2 import (
Florence2BlockV2,
)
from inference.core.workflows.core_steps.models.foundation.google_gemini.v1 import (
GoogleGeminiBlockV1,
)
Expand Down Expand Up @@ -452,6 +455,7 @@ def load_blocks() -> List[Type[WorkflowBlock]]:
DotVisualizationBlockV1,
EllipseVisualizationBlockV1,
Florence2BlockV1,
Florence2BlockV2,
GoogleGeminiBlockV1,
GoogleVisionOCRBlockV1,
HaloVisualizationBlockV1,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,14 @@
T = TypeVar("T")
K = TypeVar("K")

FLORENCE_TASKS_METADATA = {
"custom": {
"name": "Custom Prompt",
"description": "Use free-form prompt to generate a response. Useful with finetuned models.",
},
**VLM_TASKS_METADATA,
}

DETECTIONS_CLASS_NAME_FIELD = "class_name"
DETECTION_ID_FIELD = "detection_id"

Expand Down Expand Up @@ -75,12 +83,13 @@
},
{"task_type": "detection-grounded-ocr", "florence_task": "<REGION_TO_OCR>"},
{"task_type": "region-proposal", "florence_task": "<REGION_PROPOSAL>"},
{"task_type": "custom", "florence_task": None},
]
TASK_TYPE_TO_FLORENCE_TASK = {
task["task_type"]: task["florence_task"] for task in SUPPORTED_TASK_TYPES_LIST
}
RELEVANT_TASKS_METADATA = {
k: v for k, v in VLM_TASKS_METADATA.items() if k in TASK_TYPE_TO_FLORENCE_TASK
k: v for k, v in FLORENCE_TASKS_METADATA.items() if k in TASK_TYPE_TO_FLORENCE_TASK
}
RELEVANT_TASKS_DOCS_DESCRIPTION = "\n\n".join(
f"* **{v['name']}** (`{k}`) - {v['description']}"
Expand Down Expand Up @@ -125,6 +134,7 @@
TASKS_REQUIRING_PROMPT = {
"phrase-grounded-object-detection",
"phrase-grounded-instance-segmentation",
"custom",
}
TASKS_REQUIRING_CLASSES = {
"open-vocabulary-object-detection",
Expand All @@ -145,31 +155,8 @@
}


class BlockManifest(WorkflowBlockManifest):
model_config = ConfigDict(
json_schema_extra={
"name": "Florence-2 Model",
"version": "v1",
"short_description": "Run Florence-2 on an image",
"long_description": LONG_DESCRIPTION,
"license": "Apache-2.0",
"block_type": "model",
"search_keywords": ["Florence", "Florence-2", "Microsoft"],
"is_vlm_block": True,
"task_type_property": "task_type",
},
protected_namespaces=(),
)
type: Literal["roboflow_core/florence_2@v1"]
class BaseManifest(WorkflowBlockManifest):
images: Selector(kind=[IMAGE_KIND]) = ImageInputField
model_version: Union[
Selector(kind=[STRING_KIND]),
Literal["florence-2-base", "florence-2-large"],
] = Field(
default="florence-2-base",
description="Model to be used",
examples=["florence-2-base"],
)
task_type: TaskType = Field(
default="open-vocabulary-object-detection",
description="Task type to be performed by model. "
Expand Down Expand Up @@ -294,6 +281,32 @@ def get_execution_engine_compatibility(cls) -> Optional[str]:
return ">=1.3.0,<2.0.0"


class BlockManifest(BaseManifest):
model_config = ConfigDict(
json_schema_extra={
"name": "Florence-2 Model",
"version": "v1",
"short_description": "Run Florence-2 on an image",
"long_description": LONG_DESCRIPTION,
"license": "Apache-2.0",
"block_type": "model",
"search_keywords": ["Florence", "Florence-2", "Microsoft"],
"is_vlm_block": True,
"task_type_property": "task_type",
},
protected_namespaces=(),
)
type: Literal["roboflow_core/florence_2@v1"]
model_version: Union[
Selector(kind=[STRING_KIND]),
Literal["florence-2-base", "florence-2-large"],
] = Field(
default="florence-2-base",
description="Model to be used",
examples=["florence-2-base"],
)


class Florence2BlockV1(WorkflowBlock):

def __init__(
Expand Down Expand Up @@ -358,6 +371,8 @@ def run_locally(
grounding_selection_mode: GroundingSelectionMode,
) -> BlockResult:
requires_detection_grounding = task_type in TASKS_REQUIRING_DETECTION_GROUNDING

is_not_florence_task = task_type == "custom"
task_type = TASK_TYPE_TO_FLORENCE_TASK[task_type]
inference_images = [
i.to_inference_format(numpy_preferred=False) for i in images
Expand Down Expand Up @@ -385,17 +400,27 @@ def run_locally(
{"raw_output": None, "parsed_output": None, "classes": None}
)
continue
if is_not_florence_task:
prompt = single_prompt or ""
else:
prompt = task_type + (single_prompt or "")

request = LMMInferenceRequest(
api_key=self._api_key,
model_id=model_version,
image=image,
source="workflow-execution",
prompt=task_type + (single_prompt or ""),
prompt=prompt,
)
prediction = self._model_manager.infer_from_request_sync(
model_id=model_version, request=request
)
prediction_data = prediction.response[task_type]
if is_not_florence_task:
prediction_data = prediction.response[
list(prediction.response.keys())[0]
]
else:
prediction_data = prediction.response[task_type]
if task_type in TASKS_TO_EXTRACT_LABELS_AS_CLASSES:
classes = prediction_data.get("labels", [])
predictions.append(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from typing import List, Literal, Optional, Type, Union

import supervision as sv
from pydantic import ConfigDict, Field

from inference.core.workflows.core_steps.models.foundation.florence2.v1 import (
LONG_DESCRIPTION,
BaseManifest,
Florence2BlockV1,
GroundingSelectionMode,
TaskType,
)
from inference.core.workflows.execution_engine.entities.base import (
Batch,
WorkflowImageData,
)
from inference.core.workflows.execution_engine.entities.types import (
ROBOFLOW_MODEL_ID_KIND,
WorkflowParameterSelector,
)
from inference.core.workflows.prototypes.block import BlockResult, WorkflowBlockManifest


class V2BlockManifest(BaseManifest):
type: Literal["roboflow_core/florence_2@v2"]
model_id: Union[WorkflowParameterSelector(kind=[ROBOFLOW_MODEL_ID_KIND]), str] = (
Field(
default="florence-2-base",
description="Model to be used",
examples=["florence-2-base"],
json_schema_extra={"always_visible": True},
)
)
model_config = ConfigDict(
json_schema_extra={
"name": "Florence-2 Model",
"version": "v2",
"short_description": "Run Florence-2 on an image",
"long_description": LONG_DESCRIPTION,
"license": "Apache-2.0",
"block_type": "model",
"search_keywords": ["Florence", "Florence-2", "Microsoft"],
"is_vlm_block": True,
"task_type_property": "task_type",
},
protected_namespaces=(),
)


class Florence2BlockV2(Florence2BlockV1):
@classmethod
def get_manifest(cls) -> Type[WorkflowBlockManifest]:
return V2BlockManifest

def run(
self,
images: Batch[WorkflowImageData],
model_id: str,
task_type: TaskType,
prompt: Optional[str],
classes: Optional[List[str]],
grounding_detection: Optional[
Union[Batch[sv.Detections], List[int], List[float]]
],
grounding_selection_mode: GroundingSelectionMode,
) -> BlockResult:
return super().run(
images=images,
model_version=model_id,
task_type=task_type,
prompt=prompt,
classes=classes,
grounding_detection=grounding_detection,
grounding_selection_mode=grounding_selection_mode,
)
4 changes: 2 additions & 2 deletions inference/models/transformers/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,12 @@ def predict(self, image_in: Image.Image, prompt="", history=None, **kwargs):
max_new_tokens=1000,
do_sample=False,
early_stopping=False,
no_repeat_ngram_size=0,
)
generation = generation[0]
if self.generation_includes_input:
generation = generation[input_len:]

decoded = self.processor.decode(
generation, skip_special_tokens=self.skip_special_tokens
)
Expand All @@ -151,7 +153,6 @@ def get_infer_bucket_file_list(self) -> list:
"config.json",
"special_tokens_map.json",
"generation_config.json",
"model.safetensors.index.json",
"tokenizer.json",
re.compile(r"model-\d{5}-of-\d{5}\.safetensors"),
"preprocessor_config.json",
Expand Down Expand Up @@ -286,7 +287,6 @@ def get_infer_bucket_file_list(self) -> list:
"adapter_config.json",
"special_tokens_map.json",
"tokenizer.json",
"tokenizer.model",
"adapter_model.safetensors",
"preprocessor_config.json",
"tokenizer_config.json",
Expand Down

0 comments on commit cd1bb27

Please sign in to comment.