Clean pre_hook/post_hook of JointInference lib code
khalid-huang committed Jan 11, 2021
1 parent 693263f · commit 4925635
Showing 2 changed files with 11 additions and 18 deletions.
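In effect, the commit removes the pre_hook/post_hook parameters from JointInference and routes output handling through the little model's postprocess instead (input handling was already the little model's preprocess). A minimal before/after sketch of the caller-side change, assembled only from names that appear in the diffs below:

    # Before this commit: hooks were passed to JointInference
    inference_instance = neptune.joint_inference.JointInference(
        little_model=model,
        hard_example_mining_algorithm=hard_example_mining_algorithm,
        pre_hook=None,
        post_hook=post_hook,
    )

    # After this commit: postprocess is wired into the little model
    model = neptune.joint_inference.TSLittleModel(
        preprocess=preprocess,
        postprocess=postprocess,
        input_shape=input_shape,
        create_input_feed=create_input_feed,
        create_output_fetch=create_output_fetch,
    )
    inference_instance = neptune.joint_inference.JointInference(
        little_model=model,
        hard_example_mining_algorithm=hard_example_mining_algorithm,
    )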
examples/helmet_detection_inference/little_model/little_model.py (4 additions & 6 deletions)
@@ -82,7 +82,7 @@ def preprocess(image, input_shape):
     new_image.fill(128)
     bh, bw, _ = new_image.shape
     new_image[int((bh - nh) / 2):(nh + int((bh - nh) / 2)),
-        int((bw - nw) / 2):(nw + int((bw - nw) / 2)), :] = image
+              int((bw - nw) / 2):(nw + int((bw - nw) / 2)), :] = image
 
     new_image /= 255.
     new_image = np.expand_dims(new_image, 0)  # Add batch dimension.
@@ -112,7 +112,7 @@ def create_output_fetch(sess):
     return output_fetch
 
 
-def post_hook(model_output):
+def postprocess(model_output):
     all_classes, all_scores, all_bboxes = model_output
     bboxes = []
     for c, s, bbox in zip(all_classes, all_scores, all_bboxes):
@@ -171,7 +171,7 @@ def run():
     # create little model object
     model = neptune.joint_inference.TSLittleModel(
         preprocess=preprocess,
-        postprocess=None,
+        postprocess=postprocess,
         input_shape=input_shape,
         create_input_feed=create_input_feed,
         create_output_fetch=create_output_fetch
@@ -188,9 +188,7 @@ def run():
     # create joint inference object
     inference_instance = neptune.joint_inference.JointInference(
         little_model=model,
-        hard_example_mining_algorithm=hard_example_mining_algorithm,
-        pre_hook=None,
-        post_hook=post_hook,
+        hard_example_mining_algorithm=hard_example_mining_algorithm
     )
 
     # use video streams for testing
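The example diff above shows only the first lines of the renamed postprocess. For reference, here is a self-contained sketch with the same signature; the tuple unpacking follows the diff, while the record layout in the loop body is illustrative rather than the example's actual code:

    def postprocess(model_output):
        # model_output is the raw output_fetch result from session.run():
        # parallel sequences of classes, scores and bounding boxes.
        all_classes, all_scores, all_bboxes = model_output
        bboxes = []
        for c, s, bbox in zip(all_classes, all_scores, all_bboxes):
            # illustrative record layout: [x1, y1, x2, y2, score, class]
            bboxes.append(list(map(float, bbox)) + [float(s), int(c)])
        return bboxes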
lib/neptune/joint_inference/joint_inference.py (7 additions & 12 deletions)
@@ -203,7 +203,10 @@ def inference(self, img_data):
         input_feed = self.create_input_feed(self.session, new_image,
                                             img_data_np)
         output_fetch = self.create_output_fetch(self.session)
-        return self.session.run(output_fetch, input_feed)
+        output = self.session.run(output_fetch, input_feed)
+        if self.postprocess:
+            output = self.postprocess(output)
+        return output
 
 
 class LCReporter(threading.Thread):
@@ -294,13 +297,10 @@ class JointInference:
     :param little_model: the little model entity for edge inference
     :param hard_example_mining_algorithm: the algorithm for judging hard
         example
-    :param pre_hook: the pre function of edge inference
-    :param post_hook: the post function of edge inference
     """
 
     def __init__(self, little_model: BaseModel,
-                 hard_example_mining_algorithm=None,
-                 pre_hook=None, post_hook=None):
+                 hard_example_mining_algorithm=None):
         self.little_model = little_model
         self.big_model = BigModelClient()
         # TODO how to deal process use-defined cloud_offload_algorithm,
@@ -329,8 +329,6 @@ def __init__(self, little_model: BaseModel,
             hard_example_mining_algorithm = ThresholdFilter()
 
         self.hard_example_mining_algorithm = hard_example_mining_algorithm
-        self.pre_hook = pre_hook
-        self.post_hook = post_hook
 
         self.lc_reporter = LCReporter()
         self.lc_reporter.setDaemon(True)
@@ -339,13 +337,10 @@ def __init__(self, little_model: BaseModel,
     def inference(self, img_data) -> InferenceResult:
         """Image inference function."""
         img_data_pre = img_data
-        if self.pre_hook:
-            img_data_pre = self.pre_hook(img_data_pre)
         edge_result = self.little_model.inference(img_data_pre)
-        if self.post_hook:
-            edge_result = self.post_hook(edge_result)
         is_hard_example = self.hard_example_mining_algorithm.hard_judge(
-            edge_result)
+            edge_result
+        )
         if not is_hard_example:
             LOG.debug("not hard example, use edge result directly")
             self.lc_reporter.update_for_edge_inference()
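With the hooks gone, JointInference's remaining extension points are the little model and hard_example_mining_algorithm, which the diff shows defaulting to ThresholdFilter() and being invoked as hard_judge(edge_result). A sketch of a user-defined algorithm under that assumed one-method interface; the class name, threshold, and edge_result layout below are hypothetical:

    class LowConfidenceFilter:
        """Hypothetical miner: offload a frame to the big model when no
        edge detection reaches the confidence threshold."""

        def __init__(self, threshold=0.5):
            self.threshold = threshold

        def hard_judge(self, edge_result):
            # edge_result: whatever the little model's postprocess returned,
            # assumed here to be [x1, y1, x2, y2, score, class] per box.
            if not edge_result:
                return True  # no detections at all: treat as hard
            return all(box[4] < self.threshold for box in edge_result)

Such an object would be passed as JointInference(little_model=model, hard_example_mining_algorithm=LowConfidenceFilter(0.6)).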
