Skip to content

Commit

Permalink
fix runtime issue
Browse files Browse the repository at this point in the history
  • Loading branch information
david8862 committed Feb 7, 2023
1 parent a406224 commit 752a06c
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 4 deletions.
5 changes: 4 additions & 1 deletion common/backbones/imagenet_training/train_imagenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from yolo3.models.yolo3_darknet import DarkNet53
from yolo4.models.yolo4_darknet import CSPDarkNet53

#from common.utils import optimize_tf_gpu
from common.utils import optimize_tf_gpu
from common.model_utils import get_optimizer
from common.callbacks import CheckpointCleanCallBack

Expand All @@ -46,6 +46,9 @@
## set session
#K.set_session(session)

optimize_tf_gpu(tf, K)


def preprocess(image):
# random adjust color level
image = random_chroma(image)
Expand Down
5 changes: 4 additions & 1 deletion tools/evaluation/model_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.layers import Input
sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', '..'))
from common.utils import get_custom_objects
from common.utils import get_custom_objects, optimize_tf_gpu

# check tf version to be compatible with TF 2.x
if tf.__version__.startswith('2'):
Expand All @@ -19,6 +19,9 @@

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

optimize_tf_gpu(tf, K)


def clever_format(value, format_string="%.2f"):
"""
Convert statistic value to clever format string
Expand Down
4 changes: 3 additions & 1 deletion tools/evaluation/validate_yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@
from yolo3.postprocess_np import yolo3_postprocess_np
from yolo2.postprocess_np import yolo2_postprocess_np
from common.data_utils import preprocess_image
from common.utils import get_classes, get_anchors, get_colors, draw_boxes, get_custom_objects
from common.utils import get_classes, get_anchors, get_colors, draw_boxes, get_custom_objects, optimize_tf_gpu

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

optimize_tf_gpu(tf, K)


def validate_yolo_model(model, image_file, anchors, class_names, model_input_shape, elim_grid_sense, v5_decode, loop_count, output_path):
img = Image.open(image_file).convert('RGB')
Expand Down
7 changes: 6 additions & 1 deletion tracking/mot_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
from timeit import time
from collections import deque

import tensorflow as tf
import tensorflow.keras.backend as K

# implementation of SORT tracker
from model.sort.sort import Sort
# implementation of DeepSORT tracker
Expand All @@ -19,7 +22,9 @@

sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), '..'))
from yolo import YOLO, YOLO_np
from common.utils import get_classes
from common.utils import get_classes, optimize_tf_gpu

optimize_tf_gpu(tf, K)


def get_frame(frame_capture, i, images_input):
Expand Down

0 comments on commit 752a06c

Please sign in to comment.