timctho / convolutional-pose-machines-tensorflow

Help needed generating a frozen graph (protobuf, .pb format) from the shared TensorFlow checkpoints

Suraj520 opened this issue

Dear @timctho,
I want to generate a frozen graph from the TensorFlow checkpoint file that you have supplied, and I would be grateful for your help.
I am using the code below, but it fails. I need help with the `--model` argument, i.e. which model to use. If I use `cpm_hand`, it throws the following error:

NotFoundError (see above for traceback): Restoring from checkpoint failed. This is most likely due to a Variable name or other graph key that is missing from the checkpoint. Please ensure that you have not altered the graph expected based on the checkpoint. Original error:
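The NotFoundError indicates that the graph built by `get_network` asks for variables the checkpoint does not contain. One way to narrow this down is to compare the two sets of names. In TensorFlow 1.x, `tf.train.list_variables(checkpoint_path)` returns `(name, shape)` pairs for the checkpoint, and `[v.op.name for v in tf.global_variables()]` lists the graph's variables once the network has been built. The helper below is a hypothetical, framework-free sketch that diffs those two lists; the example variable names in it are made up for illustration.

```python
def diff_variable_names(graph_names, checkpoint_names):
    """Compare variable names expected by the graph with those in a checkpoint.

    Returns (missing_from_checkpoint, unused_in_checkpoint), both sorted.
    """
    graph_set = set(graph_names)
    ckpt_set = set(checkpoint_names)
    # Variables the graph needs but the checkpoint cannot provide.
    missing = sorted(graph_set - ckpt_set)
    # Variables stored in the checkpoint that the graph never asks for.
    unused = sorted(ckpt_set - graph_set)
    return missing, unused


if __name__ == "__main__":
    # Hypothetical names, for illustration only. In TF 1.x they would come from
    # [v.op.name for v in tf.global_variables()] and
    # [name for name, _ in tf.train.list_variables(checkpoint_path)].
    graph_vars = ["Convolutional_Pose_Machine/conv1/weights",
                  "Convolutional_Pose_Machine/conv1/biases"]
    ckpt_vars = ["CPM/conv1/weights", "CPM/conv1/biases"]
    missing, unused = diff_variable_names(graph_vars, ckpt_vars)
    print("In graph but not in checkpoint:", missing)
    print("In checkpoint but not used by graph:", unused)
```

If the two lists share no names at all, the model definition passed via `--model` probably does not correspond to the checkpoint.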

Code snippet for generating the frozen graph:

```python
import argparse
import os

import tensorflow as tf
from networks import get_network

# Run on the CPU only; no GPU is needed to freeze a graph.
os.environ['CUDA_VISIBLE_DEVICES'] = ''

parser = argparse.ArgumentParser(description='Tensorflow Pose Estimation Graph Extractor')
# The argument in question: which model definition matches the checkpoint?
parser.add_argument('--model', type=str, default='cpm_hand', help='')
parser.add_argument('--size', type=int, default=224)
parser.add_argument('--checkpoint', type=str, default='', help='checkpoint path')
parser.add_argument('--output_node_names', type=str, default='Convolutional_Pose_Machine/stage_5_out')
parser.add_argument('--output_graph', type=str, default='./model.pb', help='output_freeze_path')

args = parser.parse_args()

# Input placeholder: a single RGB image of size x size pixels.
input_node = tf.placeholder(tf.float32, shape=[1, args.size, args.size, 3], name="image")

with tf.Session() as sess:
    # Build the network; its variable names must match those stored in the checkpoint.
    net = get_network(args.model, input_node, trainable=False)
    saver = tf.train.Saver()
    # This is the call that raises NotFoundError when the names do not match.
    saver.restore(sess, args.checkpoint)

    input_graph_def = tf.get_default_graph().as_graph_def()
    # Replace variables with constants holding their restored values,
    # keeping only the nodes needed to compute the requested outputs.
    output_graph_def = tf.graph_util.convert_variables_to_constants(
        sess,                               # the session holding the restored weights
        input_graph_def,                    # full graph definition to prune and freeze
        args.output_node_names.split(","),  # comma-separated output node names
    )

# Serialize the frozen GraphDef to a .pb file.
with tf.gfile.GFile(args.output_graph, "wb") as f:
    f.write(output_graph_def.SerializeToString())
```
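If comparing the names shows that they differ only by a scope prefix (an assumption; a genuinely different architecture would instead require the matching model definition), `tf.train.Saver` in TF 1.x also accepts a dict `var_list` that maps checkpoint names to graph variables. Below is a minimal sketch of building such a mapping; `build_restore_map` and the prefixes in the example are hypothetical.

```python
def build_restore_map(graph_names, checkpoint_names, graph_prefix, ckpt_prefix):
    """Map checkpoint variable names to graph variable names by swapping a scope prefix.

    Returns (mapping, unmatched): mapping is {checkpoint_name: graph_name} for
    every graph variable whose renamed counterpart exists in the checkpoint;
    unmatched lists graph variables with no counterpart.
    """
    ckpt_set = set(checkpoint_names)
    mapping = {}
    unmatched = []
    for name in graph_names:
        # Rewrite the graph's scope prefix into the checkpoint's prefix.
        if name.startswith(graph_prefix):
            candidate = ckpt_prefix + name[len(graph_prefix):]
        else:
            candidate = name
        if candidate in ckpt_set:
            mapping[candidate] = name
        else:
            unmatched.append(name)
    return mapping, unmatched


if __name__ == "__main__":
    # Hypothetical names, for illustration only.
    mapping, unmatched = build_restore_map(
        ["Convolutional_Pose_Machine/conv1/weights"],
        ["CPM/conv1/weights"],
        graph_prefix="Convolutional_Pose_Machine/",
        ckpt_prefix="CPM/",
    )
    print(mapping)
    print(unmatched)
```

In TF 1.x the mapping could replace the bare `tf.train.Saver()` in the snippet above: build `var_by_name = {v.op.name: v for v in tf.global_variables()}`, then `tf.train.Saver(var_list={c: var_by_name[g] for c, g in mapping.items()})`. Any names left in `unmatched` would still be uninitialized after restoring, so freezing would fail for them.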