variable get_variable tensorflow

get_variable - TensorFlow: obteniendo variable por nombre



tf get_variable (2)

La forma más fácil de obtener una variable por nombre es buscarla en la colección tf.global_variables() :

var_23 = [v for v in tf.global_variables() if v.name == "Variable_23:0"][0]

Esto funciona bien para la reutilización ad hoc de variables existentes. En el tutorial de Compartir variables, se trata de un enfoque más estructurado, por ejemplo, cuando desea compartir variables entre varias partes de un modelo.

Al usar la API TensorFlow Python, creé una variable (sin especificar su name en el constructor) y su propiedad de name tenía el valor "Variable_23:0" . Cuando trato de seleccionar esta variable usando tf.get_variable("Variable23") , se crea una nueva variable llamada "Variable_23_1:0" su lugar. ¿Cómo selecciono correctamente "Variable_23" lugar de crear uno nuevo?

Lo que quiero hacer es seleccionar la variable por nombre, y reiniciarla para que pueda ajustar los pesos.


La función get_variable() crea una nueva variable o devuelve una creada anteriormente por get_variable() . No devolverá una variable creada usando tf.Variable() . Aquí hay un ejemplo rápido:

>>> with tf.variable_scope("foo"): ... bar1 = tf.get_variable("bar", (2,3)) # create ... >>> with tf.variable_scope("foo", reuse=True): ... bar2 = tf.get_variable("bar") # reuse ... >>> with tf.variable_scope("", reuse=True): # root variable scope ... bar3 = tf.get_variable("foo/bar") # reuse (equivalent to the above) ... >>> (bar1 is bar2) and (bar2 is bar3) True

Si no creó la variable usando tf.get_variable() , tiene un par de opciones. Primero, puede usar tf.global_variables() (como @mrry sugiere):

>>> bar1 = tf.Variable(0.0, name="bar") >>> bar2 = [var for var in tf.global_variables() if var.op.name=="bar"][0] >>> bar1 is bar2 True

O puede usar tf.get_collection() así:

>>> bar1 = tf.Variable(0.0, name="bar") >>> bar2 = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="bar")[0] >>> bar1 is bar2 True

Editar

También puedes usar get_tensor_by_name() :

>>> bar1 = tf.Variable(0.0, name="bar") >>> graph = tf.get_default_graph() >>> bar2 = graph.get_tensor_by_name("bar:0") >>> bar1 is bar2 False, bar2 is a Tensor througn convert_to_tensor on bar1. but bar1 equal bar2 in value.

Recuerde que un tensor es el resultado de una operación. Tiene el mismo nombre que la operación, más :0 . Si la operación tiene salidas múltiples, tienen el mismo nombre que la operación más :0 :1 :2 , y así sucesivamente.