examples - TensorFlow: rendimiento lento al obtener gradientes en las entradas
tensorflow playground (1)
Estoy construyendo un simple perceptrón multicapa con TensorFlow, y también necesito obtener los gradientes (o señal de error) de la pérdida en las entradas de la red neuronal.
Aquí está mi código, que funciona:
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(self.network, self.y))
optimizer = tf.train.AdagradOptimizer(learning_rate=nn_learning_rate).minimize(cost)
...
for i in range(epochs):
....
for batch in batches:
...
sess.run(optimizer, feed_dict=feed_dict)
grads_wrt_input = sess.run(tf.gradients(cost, self.x), feed_dict=feed_dict)[0]
(editado para incluir bucle de entrenamiento)
Sin la última línea (
grads_wrt_input...
), esto se ejecuta muy rápido en una máquina CUDA.
Sin embargo,
tf.gradients()
reduce el rendimiento en diez veces o más.
Recuerdo que las señales de error en los nodos se calculan como valores intermedios en el algoritmo de retropropagación, y lo he logrado con éxito utilizando la biblioteca Java DeepLearning4j.
También tenía la impresión de que esto sería una ligera modificación en el gráfico de cálculo ya construido por el
optimizer
.
¿Cómo se puede hacer esto más rápido, o hay alguna otra forma de calcular los gradientes de la pérdida con las entradas?
La función
tf.gradients()
crea un nuevo gráfico de retropropagación cada vez que se llama, por lo que la razón de la desaceleración es que TensorFlow tiene que analizar un nuevo gráfico en cada iteración del bucle.
(Esto puede ser sorprendentemente costoso: la versión actual de TensorFlow está optimizada para ejecutar el
mismo
gráfico muchas veces).
Afortunadamente, la solución es fácil: solo calcule los gradientes una vez, fuera del ciclo. Puede reestructurar su código de la siguiente manera:
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(self.network, self.y))
optimizer = tf.train.AdagradOptimizer(learning_rate=nn_learning_rate).minimize(cost)
grads_wrt_input_tensor = tf.gradients(cost, self.x)[0]
# ...
for i in range(epochs):
# ...
for batch in batches:
# ...
_, grads_wrt_input = sess.run([optimizer, grads_wrt_input_tensor],
feed_dict=feed_dict)
Tenga en cuenta que, para el rendimiento, también combiné las dos llamadas
sess.run()
.
Esto asegura que la propagación hacia adelante y gran parte de la propagación hacia atrás se reutilizarán.
Por otro lado, un consejo para encontrar errores de rendimiento como este es llamar a
tf.get_default_graph().finalize()
antes de comenzar su ciclo de entrenamiento.
Esto generará una excepción si accidentalmente agrega nodos al gráfico, lo que facilita el seguimiento de la causa de estos errores.