diff --git a/model/model.py b/model/model.py index 9265967..9b25aa3 100644 --- a/model/model.py +++ b/model/model.py @@ -329,6 +329,65 @@ def predict_step(self, session, data, summary=False, vis=False): return tag, ret_box3d_score + def save_frozen_graph(self, session, output_path): + """ + saves a frozen graph.pb file in the given path + :param session: current tensorflow session + :param output_path: path to the output frozen_graph.pb + :return: --- + """ + output_graph = tf.graph_util.convert_variables_to_constants( + session, + tf.get_default_graph().as_graph_def(), + self.get_output_nodes_names() + ) + + with tf.gfile.GFile(output_path, "wb") as f: + f.write(output_graph.SerializeToString()) + + print("%d ops in the final graph." % len(output_graph.node)) + print("\n\n frozen graph saved to {}".format(output_path)) + + def get_output_nodes(self): + """ + :return: list of all output nodes + """ + ret = [ + self.prob_output, self.delta_output, + self.box2d_ind_after_nms, + self.predict_summary + ] + + return ret + + def get_output_nodes_names(self): + """ + :return: list of the names of all output nodes (string) + """ + nodes = self.get_output_nodes() + ret = [] + for nd in nodes: + ret.append(nd.name.split(":")[0]) + + print(ret) + return ret + + + def get_input_nodes(self): + """ + :return: list of all input nodes + """ + ret = [ + self.boxes2d, self.boxes2d_scores, + self.rgb, self.bv, self.bv_heatmap + ] + for idx in range(len(self.avail_gpus)): + ret.append(self.vox_feature[idx]) + ret.append(self.vox_number[idx]) + ret.append(self.vox_coordinate[idx]) + + return ret + def average_gradients(tower_grads): # ref: # https://github.com/tensorflow/models/blob/6db9f0282e2ab12795628de6200670892a8ad6ba/tutorials/image/cifar10/cifar10_multi_gpu_train.py#L103 @@ -355,7 +414,8 @@ def average_gradients(tower_grads): return average_grads + + if __name__ == '__main__': pass -