python - network - tensorflow models
Diferencia entre Variable y get_variable en TensorFlow (4)
Hasta donde yo sé,
Variable
es la operación predeterminada para hacer una variable, y
get_variable
se usa principalmente para compartir peso.
Por un lado, hay algunas personas que sugieren usar
get_variable
lugar de la operación
Variable
primitiva siempre que necesite una variable.
Por otro lado, simplemente veo cualquier uso de
get_variable
en los documentos y demostraciones oficiales de TensorFlow.
Por lo tanto, quiero conocer algunas reglas generales sobre cómo usar correctamente estos dos mecanismos. ¿Hay algún principio "estándar"?
Otra diferencia radica en que uno está en la colección
(''variable_store'',)
pero el otro no.
Por favor vea el código code :
def _get_default_variable_store():
store = ops.get_collection(_VARSTORE_KEY)
if store:
return store[0]
store = _VariableStore()
ops.add_to_collection(_VARSTORE_KEY, store)
return store
Déjame ilustrarte eso:
import tensorflow as tf
from tensorflow.python.framework import ops
embedding_1 = tf.Variable(tf.constant(1.0, shape=[30522, 1024]), name="word_embeddings_1", dtype=tf.float32)
embedding_2 = tf.get_variable("word_embeddings_2", shape=[30522, 1024])
graph = tf.get_default_graph()
collections = graph.collections
for c in collections:
stores = ops.get_collection(c)
print(''collection %s: '' % str(c))
for k, store in enumerate(stores):
try:
print(''/t%d: %s'' % (k, str(store._vars)))
except:
print(''/t%d: %s'' % (k, str(store)))
print('''')
La salida:
collection (''__variable_store'',): 0: {''word_embeddings_2'': <tf.Variable ''word_embeddings_2:0'' shape=(30522, 1024) dtype=float32_ref>}
Puedo encontrar dos diferencias principales entre una y otra:
-
Primero es que
tf.Variable
siempre creará una nueva variable, ya sea quetf.get_variable
obtenga del gráfico una variable existente con esos parámetros, y si no existe, crea una nueva. -
tf.Variable
requiere que se especifique un valor inicial.
Es importante aclarar que la función
tf.get_variable
prefija el nombre con el alcance de la variable actual para realizar verificaciones de reutilización.
Por ejemplo:
with tf.variable_scope("one"):
a = tf.get_variable("v", [1]) #a.name == "one/v:0"
with tf.variable_scope("one"):
b = tf.get_variable("v", [1]) #ValueError: Variable one/v already exists
with tf.variable_scope("one", reuse = True):
c = tf.get_variable("v", [1]) #c.name == "one/v:0"
with tf.variable_scope("two"):
d = tf.get_variable("v", [1]) #d.name == "two/v:0"
e = tf.Variable(1, name = "v", expected_shape = [1]) #e.name == "two/v_1:0"
assert(a is c) #Assertion is true, they refer to the same object.
assert(a is d) #AssertionError: they are different objects
assert(d is e) #AssertionError: they are different objects
El último error de aserción es interesante: se supone que dos variables con el mismo nombre bajo el mismo alcance son la misma variable.
Pero si prueba los nombres de las variables
d
se dará cuenta de que Tensorflow cambió el nombre de la variable
e
:
d.name #d.name == "two/v:0"
e.name #e.name == "two/v_1:0"
Recomiendo usar siempre
tf.get_variable(...)
: facilitará la refactorización de su código si necesita compartir variables en cualquier momento, por ejemplo, en una configuración multi-gpu (consulte el multi-gpu CIFAR ejemplo).
No hay inconveniente en ello.
La
tf.Variable
pura es de nivel inferior;
en algún momento no existía
tf.get_variable()
, por lo que algunos códigos todavía usan la forma de bajo nivel.
tf.Variable es una clase, y hay varias formas de crear tf.Variable, incluyendo tf.Variable .__ init__ y tf.get_variable.
tf.Variable .__ init__: crea una nueva variable con valor_inicial .
W = tf.Variable(<initial-value>, name=<optional-name>)
tf.get_variable: Obtiene una variable existente con estos parámetros o crea una nueva. También puedes usar initializer.
W = tf.get_variable(name, shape=None, dtype=tf.float32, initializer=None,
regularizer=None, trainable=True, collections=None)
Es muy útil utilizar inicializadores como xavier_initializer:
W = tf.get_variable("W", shape=[784, 256],
initializer=tf.contrib.layers.xavier_initializer())
Más información en https://www.tensorflow.org/versions/r0.8/api_docs/python/state_ops.html#Variable .