Skip to content

Added compatibility for Tensorflow 2.0 using V1 API #26

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

def main():

Copy link

@toinsson toinsson Feb 27, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

config = tf.compat.v1.ConfigProto()
config.gpu_options.allow_growth = True
with tf.compat.v1.Session(config=config) as sess:

This will allow RTX cards to work. See tensorflow/tensorflow#24496.

with tf.Session() as sess:
with tf.compat.v1.Session() as sess:
model_cfg, model_outputs = posenet.load_model(args.model, sess)
output_stride = model_cfg['output_stride']
num_images = args.num_images
Expand Down
80 changes: 80 additions & 0 deletions environment.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
name: pose-net
channels:
- defaults
dependencies:
- _tflow_select=2.1.0=gpu
- absl-py=0.9.0=py36_0
- asn1crypto=1.3.0=py36_0
- astor=0.8.0=py36_0
- blas=1.0=mkl
- blinker=1.4=py36_0
- ca-certificates=2020.1.1=0
- cachetools=3.1.1=py_0
- certifi=2019.11.28=py36_0
- cffi=1.14.0=py36h7a1dbc1_0
- chardet=3.0.4=py36_1003
- click=7.0=py36_0
- cryptography=2.8=py36h7a1dbc1_0
- cudatoolkit=10.1.243=h74a9793_0
- cudnn=7.6.5=cuda10.1_0
- gast=0.2.2=py36_0
- google-auth=1.11.2=py_0
- google-auth-oauthlib=0.4.1=py_2
- google-pasta=0.1.8=py_0
- grpcio=1.27.2=py36h351948d_0
- h5py=2.10.0=py36h5e291fa_0
- hdf5=1.10.4=h7ebc959_0
- icc_rt=2019.0.0=h0cc432a_1
- idna=2.8=py36_0
- intel-openmp=2020.0=166
- keras-applications=1.0.8=py_0
- keras-preprocessing=1.1.0=py_1
- libprotobuf=3.11.4=h7bd577a_0
- markdown=3.1.1=py36_0
- mkl=2020.0=166
- mkl-service=2.3.0=py36hb782905_0
- mkl_fft=1.0.15=py36h14836fe_0
- mkl_random=1.1.0=py36h675688f_0
- numpy=1.18.1=py36h93ca92e_0
- numpy-base=1.18.1=py36hc3f5095_1
- oauthlib=3.1.0=py_0
- openssl=1.1.1d=he774522_4
- opt_einsum=3.1.0=py_0
- pip=20.0.2=py36_1
- protobuf=3.11.4=py36h33f27b4_0
- pyasn1=0.4.8=py_0
- pyasn1-modules=0.2.7=py_0
- pycparser=2.19=py36_0
- pyjwt=1.7.1=py36_0
- pyopenssl=19.1.0=py36_0
- pyreadline=2.1=py36_1
- pysocks=1.7.1=py36_0
- python=3.6.10=h9f7ef89_0
- pyyaml=5.3=py36he774522_0
- requests=2.22.0=py36_1
- requests-oauthlib=1.3.0=py_0
- rsa=4.0=py_0
- scipy=1.4.1=py36h9439919_0
- setuptools=45.2.0=py36_0
- six=1.14.0=py36_0
- sqlite=3.31.1=he774522_0
- tensorboard=2.1.0=py3_0
- tensorflow=2.1.0=gpu_py36h3346743_0
- tensorflow-base=2.1.0=gpu_py36h55f5790_0
- tensorflow-estimator=2.1.0=pyhd54b08b_0
- tensorflow-gpu=2.1.0=h0d30ee6_0
- termcolor=1.1.0=py36_1
- urllib3=1.25.8=py36_0
- vc=14.1=h0510ff6_4
- vs2015_runtime=14.16.27012=hf0eaf9b_1
- werkzeug=0.14.1=py36_0
- wheel=0.34.2=py36_0
- win_inet_pton=1.1.0=py36_0
- wincertstore=0.2=py36h7fe50ca_0
- wrapt=1.11.2=py36he774522_0
- yaml=0.1.7=hc54c509_2
- zlib=1.2.11=h62dcd97_3
- pip:
- opencv-python==3.4.5.20
prefix: C:\Users\jorda\Anaconda3\envs\pose-net

2 changes: 1 addition & 1 deletion image_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

def main():

with tf.Session() as sess:
with tf.compat.v1.Session() as sess:
model_cfg, model_outputs = posenet.load_model(args.model, sess)
output_stride = model_cfg['output_stride']

Expand Down
75 changes: 48 additions & 27 deletions posenet/converter/tfjs2python.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import struct
import tensorflow as tf
from tensorflow.python.tools.freeze_graph import freeze_graph

import cv2
import numpy as np
import os
Expand All @@ -20,7 +20,7 @@ def to_output_strided_layers(convolution_def, output_stride):
for _a in convolution_def:
conv_type = _a[0]
stride = _a[1]

if current_stride == output_stride:
layer_stride = 1
layer_rate = rate
Expand All @@ -29,7 +29,7 @@ def to_output_strided_layers(convolution_def, output_stride):
layer_stride = stride
layer_rate = 1
current_stride *= stride

buff.append({
'blockId': block_id,
'convType': conv_type,
Expand Down Expand Up @@ -76,7 +76,6 @@ def _read_imgfile(path, width, height):


def build_network(image, layers, variables):

def _weights(layer_name):
return variables["MobilenetV1/" + layer_name + "/weights"]['x']

Expand All @@ -87,13 +86,14 @@ def _depthwise_weights(layer_name):
return variables["MobilenetV1/" + layer_name + "/depthwise_weights"]['x']

def _conv_to_output(mobile_net_output, output_layer_name):
w = tf.nn.conv2d(mobile_net_output, _weights(output_layer_name), [1, 1, 1, 1], padding='SAME')
w = tf.nn.conv2d(input=mobile_net_output, filters=_weights(output_layer_name), strides=[1, 1, 1, 1],
padding='SAME')
w = tf.nn.bias_add(w, _biases(output_layer_name), name=output_layer_name)
return w

def _conv(inputs, stride, block_id):
return tf.nn.relu6(
tf.nn.conv2d(inputs, _weights("Conv2d_" + str(block_id)), stride, padding='SAME')
tf.nn.conv2d(input=inputs, filters=_weights("Conv2d_" + str(block_id)), strides=stride, padding='SAME')
+ _biases("Conv2d_" + str(block_id)))

def _separable_conv(inputs, stride, block_id, dilations):
Expand All @@ -104,19 +104,20 @@ def _separable_conv(inputs, stride, block_id, dilations):
pw_layer = "Conv2d_" + str(block_id) + "_pointwise"

w = tf.nn.depthwise_conv2d(
inputs, _depthwise_weights(dw_layer), stride, 'SAME', rate=dilations, data_format='NHWC')
input=inputs, filter=_depthwise_weights(dw_layer), strides=stride, padding='SAME', dilations=dilations,
data_format='NHWC')
w = tf.nn.bias_add(w, _biases(dw_layer))
w = tf.nn.relu6(w)

w = tf.nn.conv2d(w, _weights(pw_layer), [1, 1, 1, 1], padding='SAME')
w = tf.nn.conv2d(input=w, filters=_weights(pw_layer), strides=[1, 1, 1, 1], padding='SAME')
w = tf.nn.bias_add(w, _biases(pw_layer))
w = tf.nn.relu6(w)

return w

x = image
buff = []
with tf.variable_scope(None, 'MobilenetV1'):
with tf.compat.v1.variable_scope(None, 'MobilenetV1'):

for m in layers:
stride = [1, m['stride'], m['stride'], 1]
Expand Down Expand Up @@ -162,12 +163,12 @@ def convert(model_id, model_dir, check=False):
layers = to_output_strided_layers(mobile_net_arch, output_stride)
variables = load_variables(chkpoint)

init = tf.global_variables_initializer()
with tf.Session() as sess:
init = tf.compat.v1.global_variables_initializer()
with tf.compat.v1.Session() as sess:
sess.run(init)
saver = tf.train.Saver()
saver = tf.compat.v1.train.Saver()

image_ph = tf.placeholder(tf.float32, shape=[1, None, None, 3], name='image')
image_ph = tf.compat.v1.placeholder(tf.float32, shape=[1, None, None, 3], name='image')
outputs = build_network(image_ph, layers, variables)

sess.run(
Expand All @@ -182,22 +183,13 @@ def convert(model_id, model_dir, check=False):
os.makedirs(os.path.dirname(save_path))
checkpoint_path = saver.save(sess, save_path, write_state=False)

tf.train.write_graph(cg, model_dir, "model-%s.pbtxt" % chkpoint)
tf.io.write_graph(cg, model_dir, "model-%s.pbtxt" % chkpoint)

# Freeze graph and write our final model file
freeze_graph(
input_graph=os.path.join(model_dir, "model-%s.pbtxt" % chkpoint),
input_saver="",
input_binary=False,
input_checkpoint=checkpoint_path,
output_node_names='heatmap,offset_2,displacement_fwd_2,displacement_bwd_2',
restore_op_name="save/restore_all",
filename_tensor_name="save/Const:0",
output_graph=os.path.join(model_dir, "model-%s.pb" % chkpoint),
clear_devices=True,
initializer_nodes="")

if check and os.path.exists("./images/tennis_in_crowd.jpg"):
frozen_graph = freeze_session(sess, None, ['heatmap','offset_2','displacement_fwd_2','displacement_bwd_2'], True)
tf.compat.v1.train.write_graph(frozen_graph, './', os.path.join(model_dir, "model-%s.pb" % chkpoint), as_text=False)

if os.path.exists("./images/tennis_in_crowd.jpg"):
# Result
input_image = _read_imgfile("./images/tennis_in_crowd.jpg", width, height)
input_image = np.array(input_image, dtype=np.float32)
Expand All @@ -219,3 +211,32 @@ def convert(model_id, model_dir, check=False):
print(heatmaps_result[0:1, 0:1, :])
print(heatmaps_result.shape)
print(np.mean(heatmaps_result))


def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
"""
Freezes the state of a session into a pruned computation graph.

Creates a new computation graph where variable nodes are replaced by
constants taking their current value in the session. The new graph will be
pruned so subgraphs that are not necessary to compute the requested
outputs are removed.
@param session The TensorFlow session to be frozen.
@param keep_var_names A list of variable names that should not be frozen,
or None to freeze all the variables in the graph.
@param output_names Names of the relevant graph outputs.
@param clear_devices Remove the device directives from the graph for better portability.
@return The frozen graph definition.
"""
graph = session.graph
with graph.as_default():
freeze_var_names = list(set(v.op.name for v in tf.compat.v1.global_variables()).difference(keep_var_names or []))
output_names = output_names or []
output_names += [v.op.name for v in tf.compat.v1.global_variables()]
input_graph_def = graph.as_graph_def()
if clear_devices:
for node in input_graph_def.node:
node.device = ""
frozen_graph = tf.compat.v1.graph_util.convert_variables_to_constants(
session, input_graph_def, output_names, freeze_var_names)
return frozen_graph
4 changes: 2 additions & 2 deletions posenet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ def load_model(model_id, sess, model_dir=MODEL_DIR):
convert(model_ord, model_dir, check=False)
assert os.path.exists(model_path)

with tf.gfile.GFile(model_path, 'rb') as f:
graph_def = tf.GraphDef()
with tf.compat.v1.gfile.GFile(model_path, 'rb') as f:
graph_def = tf.compat.v1.GraphDef()
graph_def.ParseFromString(f.read())
sess.graph.as_default()
tf.import_graph_def(graph_def, name='')
Expand Down
2 changes: 1 addition & 1 deletion webcam_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@


def main():
with tf.Session() as sess:
with tf.compat.v1.Session() as sess:
model_cfg, model_outputs = posenet.load_model(args.model, sess)
output_stride = model_cfg['output_stride']

Expand Down