Skip to content
This repository has been archived by the owner on Mar 19, 2023. It is now read-only.

Commit

Permalink
Fix confidence
Browse files Browse the repository at this point in the history
  • Loading branch information
robmarkcole committed Jan 13, 2021
1 parent 9c16992 commit 8bfc960
Showing 1 changed file with 18 additions and 16 deletions.
34 changes: 18 additions & 16 deletions custom_components/deepstack_object/image_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from homeassistant.util.pil import draw_box
from homeassistant.components.image_processing import (
ATTR_CONFIDENCE,
CONF_CONFIDENCE,
CONF_ENTITY_ID,
CONF_NAME,
CONF_SOURCE,
Expand Down Expand Up @@ -79,7 +80,7 @@

DATETIME_FORMAT = "%Y-%m-%d_%H-%M-%S"
DEFAULT_API_KEY = ""
DEFAULT_TARGETS = [{CONF_TARGET: PERSON, ATTR_CONFIDENCE: DEFAULT_CONFIDENCE}]
DEFAULT_TARGETS = [{CONF_TARGET: PERSON}]
DEFAULT_TIMEOUT = 10
DEFAULT_ROI_Y_MIN = 0.0
DEFAULT_ROI_Y_MAX = 1.0
Expand All @@ -97,6 +98,7 @@
FILE = "file"
OBJECT = "object"
SAVED_FILE = "saved_file"
MIN_CONFIDENCE = 0.01

# rgb(red, green, blue)
RED = (255, 0, 0) # For objects within the ROI
Expand All @@ -105,8 +107,8 @@

TARGETS_SCHEMA = {
vol.Required(CONF_TARGET): cv.string,
vol.Optional(ATTR_CONFIDENCE, default=DEFAULT_CONFIDENCE): vol.All(
vol.Coerce(float), vol.Range(min=0, max=100)
vol.Optional(CONF_CONFIDENCE): vol.All(
vol.Coerce(float), vol.Range(min=1, max=100)
),
}

Expand Down Expand Up @@ -207,7 +209,6 @@ def setup_platform(hass, config, add_devices, discovery_info=None):
if save_file_folder:
save_file_folder = Path(save_file_folder)

targets = config[CONF_TARGETS] # ensure lower case
entities = []
for camera in config[CONF_SOURCE]:
object_entity = ObjectClassifyEntity(
Expand All @@ -216,8 +217,8 @@ def setup_platform(hass, config, add_devices, discovery_info=None):
api_key=config.get(CONF_API_KEY),
timeout=config.get(CONF_TIMEOUT),
custom_model=config.get(CONF_CUSTOM_MODEL),
targets=targets,
confidence=config.get(ATTR_CONFIDENCE),
targets=config.get(CONF_TARGETS),
confidence=config.get(CONF_CONFIDENCE),
roi_y_min=config[CONF_ROI_Y_MIN],
roi_x_min=config[CONF_ROI_X_MIN],
roi_y_max=config[CONF_ROI_Y_MAX],
Expand Down Expand Up @@ -261,15 +262,18 @@ def __init__(
port=port,
api_key=api_key,
timeout=timeout,
min_confidence=confidence / 100,
min_confidence=MIN_CONFIDENCE,
custom_model=custom_model,
)
self._custom_model = custom_model
self._confidence = confidence
self._targets = targets
for target in self._targets:
if CONF_CONFIDENCE not in target.keys():
target[CONF_CONFIDENCE] = self._confidence
self._targets_names = [
target[CONF_TARGET] for target in targets
] # can be a name or a type
self._confidence = confidence
self._camera = camera_entity
if name:
self._name = name
Expand Down Expand Up @@ -319,19 +323,16 @@ def process_image(self, image):

for obj in self._objects:
if obj["name"] or obj["object_type"] in self._targets_names:
## Retreive target confidence, if configured
## First check if the type has a configured confidence
## Then check if the type has a configured confidence, if yes assign
## Then if a confidence for a named object, this takes precedence over type confidence
obj_confidence = None
confidence = None
for target in self._targets:
if target[CONF_TARGET] == obj["object_type"]:
obj_confidence = target[ATTR_CONFIDENCE]
confidence = target[CONF_CONFIDENCE]
for target in self._targets:
if target[CONF_TARGET] == obj["name"]:
obj_confidence = target[ATTR_CONFIDENCE]
if not obj_confidence:
obj_confidence = self._confidence
if obj["confidence"] > obj_confidence:
confidence = target[CONF_CONFIDENCE]
if obj["confidence"] > confidence:
if object_in_roi(self._roi_dict, obj["centroid"]):
self._targets_found.append(obj)

Expand Down Expand Up @@ -382,6 +383,7 @@ def device_state_attributes(self) -> Dict:
if self._custom_model:
attr["custom_model"] = self._custom_model
attr["summary"] = self._summary
attr["objects"] = [{obj["name"]: obj["confidence"]} for obj in self._objects]
if self._save_file_folder:
attr[CONF_SAVE_FILE_FOLDER] = str(self._save_file_folder)
if self._save_timestamped_file:
Expand Down

0 comments on commit 8bfc960

Please sign in to comment.