Skip to content

Commit

Permalink
Merge graph extension and visualization
Browse files Browse the repository at this point in the history
  • Loading branch information
fschaeffler93 authored Apr 15, 2019
1 parent e314797 commit 568bd5a
Showing 1 changed file with 61 additions and 1 deletion.
62 changes: 61 additions & 1 deletion model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,65 @@ def predict_step(self, session, data, summary=False, vis=False):
return tag, ret_box3d_score


def save_frozen_graph(self, session, output_path):
"""
saves a frozen graph.pb file in the given path
:param session: current tensorflow session
:param output_path: path to the output frozen_graph.pb
:return: ---
"""
output_graph = tf.graph_util.convert_variables_to_constants(
session,
tf.get_default_graph().as_graph_def(),
self.get_output_nodes_names()
)

with tf.gfile.GFile(output_path, "wb") as f:
f.write(output_graph.SerializeToString())

print("%d ops in the final graph." % len(output_graph.node))
print("\n\n frozen graph saved to {}".format(output_path))

def get_output_nodes(self):
"""
:return: list of all output nodes
"""
ret = [
self.prob_output, self.delta_output,
self.box2d_ind_after_nms,
self.predict_summary
]

return ret

def get_output_nodes_names(self):
"""
:return: list of the names of all output nodes (string)
"""
nodes = self.get_output_nodes()
ret = []
for nd in nodes:
ret.append(nd.name.split(":")[0])

print(ret)
return ret


def get_input_nodes(self):
"""
:return: list of all input nodes
"""
ret = [
self.boxes2d, self.boxes2d_scores,
self.rgb, self.bv, self.bv_heatmap
]
for idx in range(len(self.avail_gpus)):
ret.append(self.vox_feature[idx])
ret.append(self.vox_number[idx])
ret.append(self.vox_coordinate[idx])

return ret

def average_gradients(tower_grads):
# ref:
# https://github.com/tensorflow/models/blob/6db9f0282e2ab12795628de6200670892a8ad6ba/tutorials/image/cifar10/cifar10_multi_gpu_train.py#L103
Expand All @@ -355,7 +414,8 @@ def average_gradients(tower_grads):
return average_grads




if __name__ == '__main__':
pass


0 comments on commit 568bd5a

Please sign in to comment.