examples - tensorflow neural network
AclaraciĆ³n sobre tf.Tensor.set_shape() (1)
Por lo que sé (y escribí ese código), no hay un error en Tensor.set_shape()
. Creo que el malentendido proviene del confuso nombre de ese método.
Para elaborar la entrada de Preguntas Frecuentes que citó , Tensor.set_shape()
es una función de Python puro que mejora la información de forma para un objeto tf.Tensor
dado. Por "mejora", quiero decir "hace más específico".
Por lo tanto, cuando tiene un objeto Tensor
t
con forma (?,)
, Es un tensor unidimensional de longitud desconocida. Puede llamar a t.set_shape((1028178,))
, y luego tendrá forma (1028178,)
cuando llame a t.get_shape()
. Esto no afecta al almacenamiento subyacente, ni a nada en el backend: simplemente significa que la inferencia de forma subsiguiente utilizando t
puede basarse en la afirmación de que es un vector de longitud 1028178.
Si t
tiene forma (?,)
, La llamada a t.set_shape((478, 717, 3))
fallará, porque TensorFlow ya sabe que t
es un vector, por lo que no puede tener forma (478, 717, 3)
. Si desea hacer un nuevo Tensor con esa forma a partir del contenido de t
, puede usar reshaped_t = tf.reshape(t, (478, 717, 3))
. Esto crea un nuevo objeto tf.Tensor
en Python; la implementación real de tf.reshape()
hace esto usando una copia superficial del búfer tensor, por lo que es barato en la práctica.
Una analogía es que Tensor.set_shape()
es como una Tensor.set_shape()
en tiempo de ejecución en un lenguaje orientado a objetos como Java. Por ejemplo, si tiene un puntero a un Object
pero sabe que, de hecho, es una String
, puede hacer el objeto (String) obj
para pasar obj
a un método que espere un argumento de String
. Sin embargo, si tiene un String
y trata de convertirlo en un java.util.Vector
, el compilador le dará un error, porque estos dos tipos no están relacionados.
Tengo una imagen de 478 x 717 x 3 = 1028178 píxeles, con un rango de 1. Lo verifiqué llamando a tf.shape y tf.rank.
Cuando llamo a image.set_shape ([478, 717, 3]), aparece el siguiente error.
"Shapes %s and %s must have the same rank" % (self, other))
ValueError: Shapes (?,) and (478, 717, 3) must have the same rank
Probé de nuevo con el primer lanzamiento a 1028178, pero el error aún existe.
ValueError: Shapes (1028178,) and (478, 717, 3) must have the same rank
Bueno, eso tiene sentido porque uno es de rango 1 y el otro es de rango 3. Sin embargo, ¿por qué es necesario lanzar un error, ya que el número total de píxeles aún coincide?
Por supuesto, podría usar tf.reshape y funciona, pero creo que no es óptimo.
Como se indica en las preguntas frecuentes de TensorFlow
¿Cuál es la diferencia entre x.set_shape () y x = tf.reshape (x)?
El método tf.Tensor.set_shape () actualiza la forma estática de un objeto Tensor, y normalmente se usa para proporcionar información adicional sobre la forma cuando esto no se puede inferir directamente. No cambia la forma dinámica del tensor.
La operación tf.reshape () crea un nuevo tensor con una forma dinámica diferente.
Crear un nuevo tensor implica la asignación de memoria y eso podría ser más costoso cuando se involucran más ejemplos de entrenamiento. ¿Es esto por diseño, o me estoy perdiendo algo aquí?