saver - tensorflow save model after training
¿Qué es una variable local en tensorflow? (3)
Tensorflow tiene esta API definida:
tf.local_variables()
Devuelve todas las variables creadas con la
collection=[LOCAL_VARIABLES]
.Devoluciones:
Una lista de objetos variables locales.
¿Qué es exactamente una variable local en TensorFlow? ¿Puede alguien darme un ejemplo?
Creo que aquí se requiere la comprensión de las colecciones de TensorFlow.
TensorFlow proporciona colecciones, que se denominan listas de tensores u otros objetos, como tf.Variable
instancias.
Las siguientes son colecciones en construcción:
tf.GraphKeys.GLOBAL_VARIABLES #=> ''variables''
tf.GraphKeys.LOCAL_VARIABLES #=> ''local_variables''
tf.GraphKeys.MODEL_VARIABLES #=> ''model_variables''
tf.GraphKeys.TRAINABLE_VARIABLES #=> ''trainable_variables''
En general, en el momento de la creación de una variable, se puede agregar a la colección dada pasando explícitamente esa colección como una de las colecciones pasadas al argumento de collections
.
En teoría, una variable puede estar en cualquier combinación de colecciones incorporadas o personalizadas. Pero, las colecciones en construcción se utilizan para fines particulares:
-
tf.GraphKeys.GLOBAL_VARIABLES
:- El constructor de
Variable()
oget_variable()
agrega automáticamente nuevas variables a la colecciónGraphKeys.GLOBAL_VARIABLES
del gráfico, a menos que el argumento decollections
se pase explícitamente y no incluyaGLOBAL_VARIABLE
. - Por convención, estas variables se comparten en entornos distribuidos (las variables del modelo son un subconjunto de éstas).
- Vea
tf.global_variables()
para más detalles.
- El constructor de
-
tf.GraphKeys.TRAINABLE_VARIABLES
:- Cuando se pasa
trainable=True
(que es el comportamiento predeterminado), el constructorVariable()
yget_variable()
agregan automáticamente nuevas variables a esta colección de gráficos. Pero, por supuesto, puede usar el argumento decollections
para agregar una variable a cualquier colección deseada. - Por convención, estas son las variables que serán entrenadas por un optimizador.
- Vea
tf.trainable_variables()
para más detalles.
- Cuando se pasa
-
tf.GraphKeys.LOCAL_VARIABLES
:- Puede usar
tf.contrib.framework.local_variable()
para agregar a esta colección. Pero, por supuesto, puede usar el argumento decollections
para agregar una variable a cualquier colección deseada. - Por convención, estas son las variables que son locales para cada máquina. Son variables de proceso, generalmente no guardadas / restauradas en el punto de control y utilizadas para valores temporales o intermedios. Por ejemplo, se pueden usar como contadores para el cálculo de métricas o el número de épocas que esta máquina ha leído datos.
- Vea tf.local_variables para más detalles.
- Puede usar
-
tf.GraphKeys.MODEL_VARIABLES
:- Puede usar
tf.contrib.framework.model_variable()
para agregar a esta colección. Pero, por supuesto, puede usar el argumento decollections
para agregar una variable a cualquier colección deseada. - Por convención, estas son las variables que se utilizan en el modelo para la inferencia (feed forward).
- Vea
tf.model_variables()
para más detalles.
- Puede usar
También puedes utilizar tus propias colecciones. Cualquier cadena es un nombre de colección válido, y no hay necesidad de crear explícitamente una colección. Para agregar una variable (o cualquier otro objeto) a una colección después de crear la variable, llame a tf.add_to_collection()
.
Por ejemplo,
tf.__version__ #=> ''1.9.0''
# initializing using a Tensor
my_variable01 = tf.get_variable("var01", dtype=tf.int32, initializer=tf.constant([23, 42]))
# initializing using a convenient initializer
my_variable02 = tf.get_variable("var02", shape=[1, 2, 3], dtype=tf.int32, initializer=tf.zeros_initializer)
my_variable03 = tf.get_variable("var03", dtype=tf.int32, initializer=tf.constant([1, 2]), trainable=None)
my_variable04 = tf.get_variable("var04", dtype=tf.int32, initializer=tf.constant([3, 4]), trainable=False)
my_variable05 = tf.get_variable("var05", shape=[1, 2, 3], dtype=tf.int32, initializer=tf.ones_initializer, trainable=True)
my_variable06 = tf.get_variable("var06", dtype=tf.int32, initializer=tf.constant([5, 6]), collections=[tf.GraphKeys.LOCAL_VARIABLES], trainable=None)
my_variable07 = tf.get_variable("var07", dtype=tf.int32, initializer=tf.constant([7, 8]), collections=[tf.GraphKeys.LOCAL_VARIABLES], trainable=True)
my_variable08 = tf.get_variable("var08", dtype=tf.int32, initializer=tf.constant(1), collections=[tf.GraphKeys.MODEL_VARIABLES], trainable=None)
my_variable09 = tf.get_variable("var09", dtype=tf.int32, initializer=tf.constant(2), collections=[tf.GraphKeys.GLOBAL_VARIABLES, tf.GraphKeys.LOCAL_VARIABLES, tf.GraphKeys.MODEL_VARIABLES, tf.GraphKeys.TRAINABLE_VARIABLES, "my_collectio
n"])
my_variable10 = tf.get_variable("var10", dtype=tf.int32, initializer=tf.constant(3), collections=["my_collection"], trainable=True)
[var.name for var in tf.global_variables()] #=> [''var01:0'', ''var02:0'', ''var03:0'', ''var04:0'', ''var05:0'', ''var09:0'']
[var.name for var in tf.local_variables()] #=> [''var06:0'', ''var07:0'', ''var09:0'']
[var.name for var in tf.trainable_variables()] #=> [''var01:0'', ''var02:0'', ''var05:0'', ''var07:0'', ''var09:0'', ''var10:0'']
[var.name for var in tf.model_variables()] #=> [''var08:0'', ''var09:0'']
[var.name for var in tf.get_collection("trainable_variables")] #=> [''var01:0'', ''var02:0'', ''var05:0'', ''var07:0'', ''var09:0'', ''var10:0'']
[var.name for var in tf.get_collection("my_collection")] #=> [''var09:0'', ''var10:0'']
Es igual que una variable regular, pero está en una colección diferente a la predeterminada ( GraphKeys.VARIABLES
). El ahorrador utiliza esa colección para inicializar la lista predeterminada de variables para guardar, por lo que tener una designación local
tiene el efecto de no guardar esa variable de forma predeterminada.
Sólo veo un lugar que lo usa en el código base, que es el limit_epochs
with ops.name_scope(name, "limit_epochs", [tensor]) as name:
zero64 = constant_op.constant(0, dtype=dtypes.int64)
epochs = variables.Variable(
zero64, name="epochs", trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES])
Respuesta corta : una variable local en TF es cualquier variable que fue creada con collections=[tf.GraphKeys.LOCAL_VARIABLES]
. Por ejemplo:
e = tf.Variable(6, name=''var_e'', collections=[tf.GraphKeys.LOCAL_VARIABLES])
LOCAL_VARIABLES: el subconjunto de objetos Variables que son locales para cada máquina. Generalmente se usa para variables temporales, como contadores. Nota: use tf.contrib.framework.local_variable para agregar a esta colección.
Por lo general, no se guardan / restauran en el punto de control y se usan para valores temporales o intermedios.
Respuesta larga: esto también fue una fuente de confusión para mí. Al principio pensé que las variables locales significan lo mismo que las variables locales en casi cualquier lenguaje de programación , pero no es lo mismo:
import tensorflow as tf
def some_func():
z = tf.Variable(1, name=''var_z'')
a = tf.Variable(1, name=''var_a'')
b = tf.get_variable(''var_b'', 2)
with tf.name_scope(''aaa''):
c = tf.Variable(3, name=''var_c'')
with tf.variable_scope(''bbb''):
d = tf.Variable(3, name=''var_d'')
some_func()
some_func()
print [str(i.name) for i in tf.global_variables()]
print [str(i.name) for i in tf.local_variables()]
No importa lo que intenté, siempre recibía solo variables globales:
[''var_a:0'', ''var_b:0'', ''aaa/var_c:0'', ''bbb/var_d:0'', ''var_z:0'', ''var_z_1:0'']
[]
La documentación para tf.local_variables
no ha proporcionado muchos detalles:
Variables locales: por variables de proceso, generalmente no se guardan / restauran en el punto de control y se usan para valores temporales o intermedios. Por ejemplo, se pueden usar como contadores para el cálculo de métricas o el número de épocas que esta máquina ha leído datos. La variable local_variable () agrega automáticamente una nueva variable a GraphKeys.LOCAL_VARIABLES. Esta función de conveniencia devuelve el contenido de esa colección.
Pero al leer documentos para el método init en la clase tf.Variable, descubrí que al crear una variable, puedes proporcionar qué tipo de variable quieres que sea asignando una lista de collections
.
La lista de posibles elementos de colección está here . Entonces para crear una variable local necesitas hacer algo como esto. Lo verás en la lista de local_variables
:
e = tf.Variable(6, name=''var_e'', collections=[tf.GraphKeys.LOCAL_VARIABLES])
print [str(i.name) for i in tf.local_variables()]