OValery16 / Rename-Variable-in-a-graph---Tensorflow

Python script to rename variables in a graph - Tensorflow

Rename Variable in a graph - Tensorflow

If you want to build a model that uses a pretrained TensorFlow model (for example, as a feature extractor), you might encounter the following situation: some variable names in the graph of your first model may be identical to names in the graph of your second model. If you try to restore the weights of your first model, TensorFlow will try to load them into the wrong part of your final graph, which in most cases causes errors.
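
Before renaming anything, it can help to check which names actually clash. The sketch below is a plain-Python helper; the function name `find_name_collisions` and the example variable names are hypothetical. In TensorFlow 1.x, the input lists could come from `[name for name, _ in tf.train.list_variables(checkpoint_dir)]` for a checkpoint, or from `[v.op.name for v in tf.global_variables()]` for the current graph.

```python
def find_name_collisions(names_model1, names_model2):
    """Return the variable names present in both models, sorted.

    Every name returned here is ambiguous when the first model's
    checkpoint is restored into the combined graph.
    """
    return sorted(set(names_model1) & set(names_model2))


# Hypothetical variable names, for illustration only
model1 = ['generator/conv1/weights', 'generator/conv1/biases']
model2 = ['generator/conv1/weights', 'discriminator/fc/weights']
print(find_name_collisions(model1, model2))  # ['generator/conv1/weights']
```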

The usual solution to this problem is to rename some variables:

  1. Manually rename the variable scope of your first model. For example, `with tf.variable_scope('generator'):` could become `with tf.variable_scope('generator_Model1'):`.
  2. Rename the variables contained in the checkpoint of your first model using the script `tf_rename.py` from this project. Don't forget to adapt its settings to your needs. For the previous example, this gives:
    ```python
    checkpoint_dir = 'checkpoint_dir'
    replace_substr1 = 'generator'
    replace_substr2 = 'generator_Model1'
    prefix = ''
    ```

The script then renames the variables stored in the checkpoint accordingly.
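
The name mapping applied in step 2 can be sketched as follows. This is an assumption about how the settings are used (each occurrence of `replace_substr1` is replaced by `replace_substr2`, then `prefix` is prepended), not necessarily the exact logic of `tf_rename.py`; the function `rename_variable` is hypothetical.

```python
def rename_variable(name, replace_substr1, replace_substr2, prefix=''):
    """Compute the new checkpoint name of one variable.

    Assumed behaviour: plain substring replacement of replace_substr1 by
    replace_substr2, followed by prepending prefix.
    """
    new_name = name.replace(replace_substr1, replace_substr2)
    return prefix + new_name


# Hypothetical checkpoint variable names, for illustration only
for old in ['generator/conv1/weights', 'generator/conv1/biases']:
    print(old, '->', rename_variable(old, 'generator', 'generator_Model1'))
# generator/conv1/weights -> generator_Model1/conv1/weights
# generator/conv1/biases -> generator_Model1/conv1/biases
```

Because this is a plain substring replacement, a name such as `pre_generator/w` would also be rewritten; choosing a distinctive scope name (or matching on `'generator/'`) avoids such accidental matches.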

  3. Restore the weights of the first model from the code of your second model, for example:
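```python
import tensorflow as tf  # TensorFlow 1.x API (tf.compat.v1 under TF 2.x)

# Collect only the variables of the first model, now under 'generator_Model1'
var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='generator_Model1')
weight_initializer = tf.train.Saver(var_list)

# Define the initialization operation
init_op = tf.global_variables_initializer()

config = tf.ConfigProto()
config.gpu_options.allow_growth = True
with tf.Session(config=config) as sess:
    # Initialize every variable first, so the restore below only overwrites
    # the variables of the first model
    sess.run(init_op)
    # Load the pretrained model; FLAGS.checkpoint is the path to the renamed checkpoint
    print('Loading weights from the pre-trained model')
    weight_initializer.restore(sess, FLAGS.checkpoint)
```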
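
As a variation on steps 2 and 3, `tf.train.Saver` in TensorFlow 1.x also accepts a dict that maps checkpoint names to variables, so the original, un-renamed checkpoint can be read directly. The sketch below builds such a dict; the helper `build_restore_map` is hypothetical, and it assumes the only difference between graph and checkpoint names is the scope substring.

```python
def build_restore_map(named_vars, graph_substr, ckpt_substr):
    """Map checkpoint names to graph variables.

    named_vars: iterable of (graph_name, variable) pairs. In TensorFlow 1.x
    this could be ((v.op.name, v) for v in var_list).
    Each graph name has graph_substr replaced by ckpt_substr to recover the
    name stored in the original checkpoint.
    """
    return {name.replace(graph_substr, ckpt_substr): var
            for name, var in named_vars}


# Possible use with the code above (TensorFlow 1.x):
#   restore_map = build_restore_map(((v.op.name, v) for v in var_list),
#                                   'generator_Model1', 'generator')
#   weight_initializer = tf.train.Saver(restore_map)
```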
