Variables in TensorFlow

Today we look at how to use and manage variables in TensorFlow. A neural network contains so many variables that using and managing them is challenging, and sometimes the same variables must be shared within one model. Let's see how TensorFlow handles these issues.

I will cover three aspects:

  1. Variable Use in TensorFlow
  2. Variable management in TensorFlow
  3. Variable Save and Restore in TensorFlow

I. Variable Use in TensorFlow

Variables hold the parameters of a model, such as the weights, convolution kernels, and biases of a convolutional neural network (CNN). These parameters are updated to make the CNN perform better when we train the model with gradient descent.

There are two main ways to create a new variable: tf.Variable() and tf.get_variable().

1. tf.Variable()

We can create a new variable with tf.Variable(). Note that we can give different variables the same name, because TensorFlow checks each name and modifies it if it is already taken. For example, TensorFlow names the first unnamed variable "Variable" and the next unnamed variable "Variable_1". In this way TensorFlow resolves variable name conflicts for us. Let's see the code below:

import tensorflow as tf

var = tf.Variable(1)
var1 = tf.Variable(2)
var2 = tf.Variable(3, name='foo')
var3 = tf.Variable(4, name='foo')

This program defines four variables: var, var1, var2, and var3. These are only the names we use in the Python program, not the names TensorFlow uses internally. In the TensorFlow graph, var is named "Variable", var1 is named "Variable_1", var2 is named "foo", and var3 is named "foo_1" (the .name attribute of each variable also carries an output suffix such as ":0").
In the following lines we can use these variables however we want, and TensorFlow will manage them for us in the graph and the session.

Note: with tf.Variable() we don't need to worry about name conflicts, because TensorFlow takes care of them.
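To check the naming rule described above, the following minimal sketch prints the graph name of each variable. It assumes a TensorFlow installation that exposes the 1.x API as tensorflow.compat.v1 (an assumption, not something the program above requires), and it builds everything inside a fresh tf.Graph() so that names left over from earlier code do not interfere.

```python
# Assumption: the TF 1.x API is available as tensorflow.compat.v1.
import tensorflow.compat.v1 as tf

# A fresh graph guarantees that no names are already taken.
with tf.Graph().as_default():
    var = tf.Variable(1)
    var1 = tf.Variable(2)
    var2 = tf.Variable(3, name='foo')
    var3 = tf.Variable(4, name='foo')

    # .name is the name TensorFlow uses in the graph, plus an output suffix.
    names = [v.name for v in (var, var1, var2, var3)]
    print(names)
```

The second 'foo' is renamed to 'foo_1' instead of raising an error, which is the behavior the note above relies on.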

2. tf.get_variable()

We have another choice for creating a variable in TensorFlow: tf.get_variable(). First, let's look at the signature of this function.

tf.get_variable(
    name,
    shape=None,
    dtype=None,
    initializer=None,
    regularizer=None,
    trainable=None,
    collections=None,
    caching_device=None,
    partitioner=None,
    validate_shape=True,
    use_resource=None,
    custom_getter=None,
    constraint=None,
    synchronization=tf.VariableSynchronization.AUTO,
    aggregation=tf.VariableAggregation.NONE
)

We can use it like this:

import tensorflow as tf

var = tf.get_variable('var', [1])

Depending on the enclosing variable scope, tf.get_variable() either returns an existing variable with the given name or uses these parameters to create a new one. It is important to distinguish tf.Variable() from tf.get_variable().

tf.get_variable() is a higher-level counterpart of tf.Variable(). The differences between the two show up when we need to share variables or use scopes to manage them. I will go into details later.
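As a first taste of that difference, the sketch below asks for the same name twice. It assumes the 1.x API is reachable as tensorflow.compat.v1 and uses a fresh graph; the flag name duplicate_rejected is only illustrative. Without reuse enabled, tf.get_variable() refuses to create a second variable with an existing name, while tf.Variable() silently picks a new name.

```python
# Assumption: the TF 1.x API is available as tensorflow.compat.v1.
import tensorflow.compat.v1 as tf

with tf.Graph().as_default():
    first = tf.get_variable('var', [1])

    # Asking for the same name again, without reuse, is an error.
    try:
        tf.get_variable('var', [1])
        duplicate_rejected = False
    except ValueError:
        duplicate_rejected = True

    # tf.Variable() does not complain; it gets a different graph name.
    other = tf.Variable([0.0], name='var')

    print(first.name, other.name, duplicate_rejected)
```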

II. Variable management in TensorFlow

1. scope

A deeper neural network has more parameters, and more parameters make the variables harder to manage. TensorFlow provides scopes to manage the variables of a CNN efficiently. Two context managers are involved: tf.name_scope() and tf.variable_scope().

Both create a context manager, but they differ as follows:

  • name_scope affects all ops except variables created with get_variable();
  • variable_scope will add scope as a prefix to all operations and variables.

Let's look at the code below:

import tensorflow as tf

def scoping(fn, scope1, scope2, vals):
    # fn is either tf.variable_scope or tf.name_scope
    with fn(scope1):
        a = tf.Variable(vals[0], name='a')
        b = tf.get_variable('b', initializer=vals[1])
        c = tf.constant(vals[2], name='c')
        with fn(scope2):
            d = tf.add(a * b, c, name='res')

        # Show the graph names that this kind of scope produced
        print('\n  '.join([scope1, a.name, b.name, c.name, d.name]), '\n')
    return d

d1 = scoping(tf.variable_scope, 'scope_vars', 'res', [1, 2, 3])
d2 = scoping(tf.name_scope,     'scope_name', 'res', [1, 2, 3])

with tf.Session() as sess:
    # Write the graph so it can be inspected in TensorBoard
    writer = tf.summary.FileWriter('logs', sess.graph)
    sess.run(tf.global_variables_initializer())
    print(sess.run([d1, d2]))
    writer.close()

We define a function named scoping() and call it with different arguments to compare the output:

scope_vars
  scope_vars/a:0
  scope_vars/b:0
  scope_vars/c:0
  scope_vars/res/res:0

scope_name
  scope_name/a:0
  b:0
  scope_name/c:0
  scope_name/res/res:0

Because the script writes the graph to the logs directory, we can also inspect these scopes in TensorBoard.

2. Sharing variables

  • We should use tf.get_variable() (not tf.Variable()) to create variables that we want to reuse later.
  • If we create a variable with tf.Variable(), we cannot reuse it through tf.get_variable().
  • We should use tf.variable_scope() to manage the variables that we want to reuse.
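The three rules above can be sketched as follows. This is a minimal illustration, not a prescribed pattern: it assumes the 1.x API is available as tensorflow.compat.v1, and the helper dense() and the scope names 'layer' and 'plain' are hypothetical names chosen for the example. The same weight is created once and then looked up again by re-entering the scope with reuse=True.

```python
# Assumption: the TF 1.x API is available as tensorflow.compat.v1.
import tensorflow.compat.v1 as tf

def dense(x):
    # Hypothetical helper: its weight is created or looked up by name.
    w = tf.get_variable('w', shape=[2, 1], initializer=tf.ones_initializer())
    return tf.matmul(x, w)

with tf.Graph().as_default():
    x1 = tf.constant([[1.0, 2.0]])
    x2 = tf.constant([[3.0, 4.0]])

    with tf.variable_scope('layer'):
        y1 = dense(x1)                    # creates layer/w
    with tf.variable_scope('layer', reuse=True):
        y2 = dense(x2)                    # returns the existing layer/w

    # Only one variable exists even though dense() was called twice.
    shared_names = [v.name for v in tf.global_variables()]

    # A variable made with tf.Variable() is invisible to tf.get_variable().
    with tf.variable_scope('plain'):
        tf.Variable(0.0, name='a')
    try:
        with tf.variable_scope('plain', reuse=True):
            tf.get_variable('a', [])
        plain_reused = True
    except ValueError:
        plain_reused = False

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        outputs = sess.run([y1, y2])

print(shared_names, plain_reused, outputs)
```

Passing reuse=tf.AUTO_REUSE instead of reuse=True lets a single scope block create the variable on the first call and reuse it on later calls.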

III. Variable Save and Restore in TensorFlow

It is necessary to save the variables once we have a satisfactory model and to restore them when we want to use the model again. TensorFlow provides tf.train.Saver() for this job.

  • tf.train.Saver().save()
  • tf.train.Saver().restore()

1. Saving a Model

import tensorflow as tf
a = tf.get_variable('a', [])
b = tf.get_variable('b', [])
init = tf.global_variables_initializer()

saver = tf.train.Saver()
sess = tf.Session()
sess.run(init)
saver.save(sess, './tftcp.model')

It will output four new files:

checkpoint
tftcp.model.data-00000-of-00001
tftcp.model.index
tftcp.model.meta

  • tftcp.model.data-00000-of-00001 contains the weights of your model. It is most likely the largest file here.
  • tftcp.model.meta is the network structure of your model. It contains all the information needed to re-create your graph.
  • tftcp.model.index is an indexing structure linking the first two things. It says “where in the data file do I find the parameters corresponding to this node?”
  • checkpoint is not actually needed to reconstruct your model, but if you save multiple versions of your model throughout a training run, it keeps track of everything.
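To see these files and what they store, the sketch below saves two variables into a temporary directory and then inspects the result. It assumes the 1.x API is available as tensorflow.compat.v1; the temporary directory and the constant initial values are choices made for this example only.

```python
# Assumption: the TF 1.x API is available as tensorflow.compat.v1.
import os
import tempfile
import tensorflow.compat.v1 as tf

ckpt_dir = tempfile.mkdtemp()                  # scratch directory for the example
prefix = os.path.join(ckpt_dir, 'tftcp.model')

with tf.Graph().as_default():
    a = tf.get_variable('a', [], initializer=tf.constant_initializer(1.0))
    b = tf.get_variable('b', [], initializer=tf.constant_initializer(2.0))
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.save(sess, prefix)

# The files written next to the checkpoint prefix.
files = sorted(os.listdir(ckpt_dir))

# The 'checkpoint' file is what latest_checkpoint() reads.
latest = tf.train.latest_checkpoint(ckpt_dir)

# (name, shape) pairs stored in the .index/.data files.
stored = tf.train.list_variables(prefix)

print(files)
print(latest)
print(stored)
```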

2. Loading a Model

import tensorflow as tf
a = tf.get_variable('a', [])
b = tf.get_variable('b', [])

saver = tf.train.Saver()
sess = tf.Session()
saver.restore(sess, './tftcp.model')
print(sess.run([a, b]))  # restored values; no initializer is run
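The two programs above can be combined into one round trip to confirm that restore() brings back exactly the values that were saved. This is a sketch under a few assumptions: the 1.x API is available as tensorflow.compat.v1, the checkpoint goes to a temporary directory, and the values 1.5 and -2.0 are arbitrary. Two separate graphs stand in for the two separate programs.

```python
# Assumption: the TF 1.x API is available as tensorflow.compat.v1.
import os
import tempfile
import tensorflow.compat.v1 as tf

prefix = os.path.join(tempfile.mkdtemp(), 'tftcp.model')

# "Training" program: give the variables known values and save them.
with tf.Graph().as_default():
    a = tf.get_variable('a', [], initializer=tf.constant_initializer(1.5))
    b = tf.get_variable('b', [], initializer=tf.constant_initializer(-2.0))
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        tf.train.Saver().save(sess, prefix)

# "Inference" program: same variable names, values come from the checkpoint.
with tf.Graph().as_default():
    a = tf.get_variable('a', [])
    b = tf.get_variable('b', [])
    with tf.Session() as sess:
        # restore() assigns the saved values, so no initializer is needed.
        tf.train.Saver().restore(sess, prefix)
        restored = [float(v) for v in sess.run([a, b])]

print(restored)
```

The variables are matched by name, so the second graph must declare them with the same names and shapes that were used when saving.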