tutorial tensores programar imagenes examples español ejemplos con como clasificador basicos tensorflow

tensores - ¿Cómo seleccionar las filas de un Tensor 3-D en TensorFlow?



tensorflow python español (2)

Esto es posible en TensorFlow, pero un poco incómodo, ya que tf.gather() actualmente solo funciona con índices unidimensionales, y solo selecciona cortes de la dimensión 0 de un tensor. Sin embargo, aún es posible resolver su problema de manera eficiente, transformando los argumentos para que puedan pasarse a tf.gather() :

logits = ... # [2 x 4 x 4] tensor indices = tf.constant([[0, 1], [1, 3]]) # Use tf.shape() to make this work with dynamic shapes. batch_size = tf.shape(logits)[0] rows_per_batch = tf.shape(logits)[1] indices_per_batch = tf.shape(indices)[1] # Offset to add to each row in indices. We use `tf.expand_dims()` to make # this broadcast appropriately. offset = tf.expand_dims(tf.range(0, batch_size) * rows_per_batch, 1) # Convert indices and logits into appropriate form for `tf.gather()`. flattened_indices = tf.reshape(indices + offset, [-1]) flattened_logits = tf.reshape(logits, tf.concat(0, [[-1], tf.shape(logits)[2:]])) selected_rows = tf.gather(flattened_logits, flattened_indices) result = tf.reshape(selected_rows, tf.concat(0, [tf.pack([batch_size, indices_per_batch]), tf.shape(logits)[2:]]))

Tenga en cuenta que, dado que esto utiliza tf.reshape() y no tf.transpose() , no necesita modificar los datos (potencialmente grandes) en el tensor de logits , por lo que debería ser bastante eficiente.

Tengo un tensor logits con las dimensiones [batch_size, num_rows, num_coordinates] (es decir, cada logit en el lote es una matriz). En mi caso, el tamaño del lote es 2, hay 4 filas y 4 coordenadas.

logits = tf.constant([[[10.0, 10.0, 20.0, 20.0], [11.0, 10.0, 10.0, 30.0], [12.0, 10.0, 10.0, 20.0], [13.0, 10.0, 10.0, 20.0]], [[14.0, 11.0, 21.0, 31.0], [15.0, 11.0, 11.0, 21.0], [16.0, 11.0, 11.0, 21.0], [17.0, 11.0, 11.0, 21.0]]])

Quiero seleccionar la primera y la segunda fila del primer lote y la segunda y cuarta fila del segundo lote.

indices = tf.constant([[0, 1], [1, 3]])

Entonces, el resultado deseado sería

logits = tf.constant([[[10.0, 10.0, 20.0, 20.0], [11.0, 10.0, 10.0, 30.0]], [[15.0, 11.0, 11.0, 21.0], [17.0, 11.0, 11.0, 21.0]]])

¿Cómo hago esto usando TensorFlow? Intenté usar tf.gather(logits, indices) pero no devolvió lo que esperaba. ¡Gracias!


La respuesta de mrry es genial, pero creo que con la función tf.gather_nd el problema se puede resolver con muchas menos líneas de código (probablemente esta función aún no estaba disponible en el momento de la escritura de mrry):

logits = tf.constant([[[10.0, 10.0, 20.0, 20.0], [11.0, 10.0, 10.0, 30.0], [12.0, 10.0, 10.0, 20.0], [13.0, 10.0, 10.0, 20.0]], [[14.0, 11.0, 21.0, 31.0], [15.0, 11.0, 11.0, 21.0], [16.0, 11.0, 11.0, 21.0], [17.0, 11.0, 11.0, 21.0]]]) indices = tf.constant([[[0, 0], [0, 1]], [[1, 1], [1, 3]]]) result = tf.gather_nd(logits, indices) with tf.Session() as sess: print(sess.run(result))

Esto se imprimirá

[[[ 10. 10. 20. 20.] [ 11. 10. 10. 30.]] [[ 15. 11. 11. 21.] [ 17. 11. 11. 21.]]]

tf.gather_nd debería estar disponible a partir de v0.10. Mira este problema de github para más discusiones sobre esto.