tf.Variable() and tf.get_variable()

函数定义

tf.Variable

# tf.Variable
__init__(
    initial_value=None,
    trainable=True,
    collections=None,
    validate_shape=True,
    caching_device=None,
    name=None,
    variable_def=None,
    dtype=None,
    expected_shape=None,
    import_scope=None,
    constraint=None,
    use_resource=None,
    synchronization=tf.VariableSynchronization.AUTO,
    aggregation=tf.VariableAggregation.NONE
)

tf.get_variable():

tf.get_variable(
    name,
    shape=None,
    dtype=None,
    initializer=None,
    regularizer=None,
    trainable=None,
    collections=None,
    caching_device=None,
    partitioner=None,
    validate_shape=True,
    use_resource=None,
    custom_getter=None,
    constraint=None,
    synchronization=tf.VariableSynchronization.AUTO,
    aggregation=tf.VariableAggregation.NONE
)

从一段代码说起

#......
var1 = tf.Variable(1,name='v')
var2 = tf.get_variable('v', [1])
#......

这段代码中,先使用 tf.Variable() 生成了一个变量,该变量用常量1进行初始化,并且name='v'。使用tf.Variable()时,如果发生命名冲突,tf.Variable()会自行处理,处理方式为:使用_1, _2这样的后缀来避免冲突。

tf.get_variable()也是用来生成一个新的变量或者获取一个已存在的变量。在这里,会创建一个新变量v_1
tf.get_variable()的具体行为与tf.variable_scope()中的reuse有关:

  • reuse=None:默认情况,若变量存在(使用tf.get_variable()创建),则返回已存在变量;若变量不存在,创建之;
  • reuse=True:使用变量共享,这样tf.get_variable()的功能就是获取一个共享变量,若该变量不存在,报错之;
  • reuse=False:关闭变量共享,这样tf.get_variable()的功能就是创建一个变量,若变量存在,通过添加后缀(_1, _2, ....)来避免命名冲突。

这里需要注意的是:

  1. tf.Variable()生成的变量,不可以使用tf.get_variable()来进行变量共享;
  2. tf.get_variable()tf.variable_scope()配合使用,但不受tf.name_scope()约束;
  3. 编程实践中,建议全部使用tf.get_variable()来维护变量;

tf.Variable会在Graph中生成很多Node(v/initial_value, v, v/Assign,......)。需要注意的Node是v

node {
  name: "v"
  op: "VariableV2"
  attr {
    key: "container"
    value {
      s: ""
    }
  }
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "shape"
    value {
      shape {
      }
    }
  }
  attr {
    key: "shared_name"
    value {
      s: ""
    }
  }
}

tf.get_variable() 也会生成很多对应节点,我们只看v_1

node {
  name: "v_1"
  op: "VariableV2"
  attr {
    key: "_class"
    value {
      list {
        s: "loc:@v_1"
      }
    }
  }
  attr {
    key: "container"
    value {
      s: ""
    }
  }
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "shape"
    value {
      shape {
        dim {
          size: 1
        }
      }
    }
  }
  attr {
    key: "shared_name"
    value {
      s: ""
    }
  }
}

对比tf.Variable()生成的vtf.get_variable()生成的v_1,会发现,tf.get_variable()生成的v_1中多了一个属性:

attr {
    key: "_class"
    value {
      list {
        s: "loc:@v_1"
      }
    }
  }

该属性的作用是在同一个设备上协同定位结点,事实上,凡是具有class属性的节点,都会被放在同一设备上进行维护。我想就是因为tf.get_variable()在graph中生成节点的时候就知道了该节点分配的位置,所以方便了变量共享,而对于tf.Variable()则没有这样的属性,因此在未运行会话前,是不知道变量位置的;因此也就无法完成变量共享。

总结

  1. 使用tf.get_variable()维护变量;
  2. 使用tf.get_variable()维护的变量,不受tf.name_scope()约束;所以,推荐使用tf.variable_scope()进行命名空间维护。