python - reconocimiento - Cómo elegir los últimos valores de salida válidos de tensorflow RNN
tensorflow reconocimiento de imagen (3)
Danijar publicó una solución más aceptable en la página de solicitud de funciones que he vinculado en la pregunta. No es necesario evaluar los tensores, lo cual es una gran ventaja.
Lo tengo para trabajar con tensorflow 0.8. Aquí está el código:
def extract_last_relevant(outputs, length):
"""
Args:
outputs: [Tensor(batch_size, output_neurons)]: A list containing the output
activations of each in the batch for each time step as returned by
tensorflow.models.rnn.rnn.
length: Tensor(batch_size): The used sequence length of each example in the
batch with all later time steps being zeros. Should be of type tf.int32.
Returns:
Tensor(batch_size, output_neurons): The last relevant output activation for
each example in the batch.
"""
output = tf.transpose(tf.pack(outputs), perm=[1, 0, 2])
# Query shape.
batch_size = tf.shape(output)[0]
max_length = int(output.get_shape()[1])
num_neurons = int(output.get_shape()[2])
# Index into flattened array as a workaround.
index = tf.range(0, batch_size) * max_length + (length - 1)
flat = tf.reshape(output, [-1, num_neurons])
relevant = tf.gather(flat, index)
return relevant
Estoy entrenando una célula LSTM en lotes de secuencias que tienen diferentes longitudes. El tf.nn.rnn
tiene el parámetro muy conveniente sequence_length
, pero después de llamarlo, no sé cómo seleccionar las filas de salida correspondientes al último paso de tiempo de cada elemento en el lote.
Mi código es básicamente el siguiente:
lstm_cell = tf.nn.rnn_cell.LSTMCell(num_lstm_units, input_size)
lstm_outputs, state = tf.nn.rnn(lstm_cell, input_list, dtype=tf.float32, sequence_length=sequence_lengths)
lstm_outputs
es una lista con la salida LSTM en cada paso de tiempo. Sin embargo, cada elemento de mi lote tiene una longitud diferente, por lo que me gustaría crear un tensor que contenga la última salida LSTM válida para cada elemento de mi lote.
Si pudiera usar la indexación numpy, simplemente haría algo como esto:
all_outputs = tf.pack(lstm_outputs)
last_outputs = all_outputs[sequence_lengths, tf.range(batch_size), :]
Pero resulta que, por el momento, tensorflow no lo admite (conozco la solicitud de funciones ).
Entonces, ¿cómo podría obtener estos valores?
No es la mejor solución, pero podrías evaluar tus resultados y luego utilizar la indexación numpy para obtener los resultados y crear una variable de tensor a partir de eso. Podría funcionar como un espacio de detención hasta que tensorflow obtenga esta característica. p.ej
all_outputs = session.run(lstm_outputs, feed_dict={''your inputs''})
last_outputs = all_outputs[sequence_lengths, tf.range(batch_size), :]
use_this_as_an_input_to_new_tensorflow_op = tf.constant(last_outputs)
si solo está interesado en la última salida válida, puede recuperarla a través del estado devuelto por tf.nn.rnn()
teniendo en cuenta que siempre es una tupla (c, h) donde c es el último estado y h es la última salida . Cuando el estado es una LSTMStateTuple
, puede usar el siguiente fragmento de código (que trabaja en tensorflow 0.12):
lstm_cell = tf.nn.rnn_cell.LSTMCell(num_lstm_units, input_size)
lstm_outputs, state = tf.nn.rnn(lstm_cell, input_list, dtype=tf.float32, sequence_length=sequence_lengths)
last_output = state[1]