Skip to content

Commit 58a4611

Browse files
type hints for benchmark (and also some missing ones in dlclive)
1 parent 4004ef9 commit 58a4611

File tree

2 files changed

+36
-36
lines changed

2 files changed

+36
-36
lines changed

dlclive/benchmark.py

Lines changed: 33 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import sys
1313
import warnings
1414
import subprocess
15-
import typing
15+
from typing import List, Optional, Tuple, Union
1616
import pickle
1717
import colorcet as cc
1818
from PIL import ImageColor
@@ -148,22 +148,22 @@ def get_system_info() -> dict:
148148

149149

150150
def benchmark(
151-
model_path,
152-
video_path,
153-
tf_config=None,
154-
resize=None,
155-
pixels=None,
156-
cropping=None,
157-
dynamic=(False, 0.5, 10),
158-
n_frames=1000,
159-
print_rate=False,
160-
display=False,
161-
pcutoff=0.0,
162-
display_radius=3,
163-
cmap="bmy",
164-
save_poses=False,
165-
save_video=False,
166-
output=None,
151+
model_path: str,
152+
video_path: str,
153+
tf_config: Optional[tf.ConfigProto] = None,
154+
resize: Optional[float] = None,
155+
pixels: Optional[int] = None,
156+
cropping: Optional[List[int]] = None,
157+
dynamic: Tuple[bool, float, int] = (False, 0.5, 10),
158+
n_frames: int = 1000,
159+
print_rate: bool = False,
160+
display: bool = False,
161+
pcutoff: float = 0.0,
162+
display_radius: int = 3,
163+
cmap: str = "bmy",
164+
save_poses: bool = False,
165+
save_video: bool = False,
166+
output: Optional[str] = None,
167167
) -> typing.Tuple[np.ndarray, tuple, bool, dict]:
168168
""" Analyze DeepLabCut-live exported model on a video:
169169
Calculate inference time,
@@ -516,22 +516,22 @@ def save_inf_times(
516516

517517

518518
def benchmark_videos(
519-
model_path,
520-
video_path,
521-
output=None,
522-
n_frames=1000,
523-
tf_config=None,
524-
resize=None,
525-
pixels=None,
526-
cropping=None,
527-
dynamic=(False, 0.5, 10),
528-
print_rate=False,
529-
display=False,
530-
pcutoff=0.5,
531-
display_radius=3,
532-
cmap="bmy",
533-
save_poses=False,
534-
save_video=False,
519+
model_path: str,
520+
video_path: Union[str, List[str]],
521+
output: Optional[str] = None,
522+
n_frames: int = 1000,
523+
tf_config: Optional[tf.ConfigProto] = None,
524+
resize: Optional[Union[float, List[float]]] = None,
525+
pixels: Optional[Union[int, List[int]]] = None,
526+
cropping: Optional[List[int]] = None,
527+
dynamic: Tuple[bool, float, int] = (False, 0.5, 10),
528+
print_rate: bool = False,
529+
display: bool = False,
530+
pcutoff: float = 0.5,
531+
display_radius: int = 3,
532+
cmap: str = "bmy",
533+
save_poses: bool = False,
534+
save_video: bool = False,
535535
):
536536
"""Analyze videos using DeepLabCut-live exported models.
537537
Analyze multiple videos and/or multiple options for the size of the video

dlclive/dlclive.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ def parameterization(self) -> dict:
181181
"""
182182
return {param: getattr(self, param) for param in self.PARAMETERS}
183183

184-
def process_frame(self, frame):
184+
def process_frame(self, frame: np.ndarray) -> np.ndarray:
185185
"""
186186
Crops an image according to the object's cropping and dynamic properties.
187187
@@ -237,7 +237,7 @@ def process_frame(self, frame):
237237

238238
return frame
239239

240-
def init_inference(self, frame=None, **kwargs):
240+
def init_inference(self, frame=None, **kwargs) -> np.ndarray:
241241
"""
242242
Load model and perform inference on first frame -- the first inference is usually very slow.
243243
@@ -376,7 +376,7 @@ def init_inference(self, frame=None, **kwargs):
376376

377377
return pose
378378

379-
def get_pose(self, frame=None, **kwargs):
379+
def get_pose(self, frame=None, **kwargs) -> np.ndarray:
380380
"""
381381
Get the pose of an image
382382

0 commit comments

Comments
 (0)