redes - tensorflow tutorial
El entrenamiento de una red neuronal completamente convolucional con entradas de tamaƱo variable lleva demasiado tiempo en Keras/TensorFlow (0)
Estoy tratando de implementar un FCNN para la clasificación de imágenes que pueda aceptar entradas de tamaño variable. El modelo está construido en Keras con el backend TensorFlow.
Considere el siguiente ejemplo de juguete:
model = Sequential()
# width and height are None because we want to process images of variable size
# nb_channels is either 1 (grayscale) or 3 (rgb)
model.add(Convolution2D(32, 3, 3, input_shape=(nb_channels, None, None), border_mode=''same''))
model.add(Activation(''relu''))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Convolution2D(32, 3, 3, border_mode=''same''))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Convolution2D(16, 1, 1))
model.add(Activation(''relu''))
model.add(Convolution2D(8, 1, 1))
model.add(Activation(''relu''))
# reduce the number of dimensions to the number of classes
model.add(Convolution2D(nb_classses, 1, 1))
model.add(Activation(''relu''))
# do global pooling to yield one value per class
model.add(GlobalAveragePooling2D())
model.add(Activation(''softmax''))
Este modelo funciona bien, pero me encuentro con un problema de rendimiento. El entrenamiento en imágenes de tamaño variable lleva un tiempo irracionalmente largo en comparación con el entrenamiento en las entradas de tamaño fijo. Si cambio el tamaño de todas las imágenes al tamaño máximo en el conjunto de datos, aún se necesita mucho menos tiempo para entrenar el modelo que para entrenar en la entrada de tamaño variable. Entonces, ¿ input_shape=(nb_channels, None, None)
la forma correcta de especificar la entrada de tamaño variable? ¿Y hay alguna forma de mitigar este problema de rendimiento?
Actualizar
model.summary()
para un modelo con 3 clases e imágenes en escala de grises:
Layer (type) Output Shape Param # Connected to
====================================================================================================
convolution2d_1 (Convolution2D) (None, 32, None, None 320 convolution2d_input_1[0][0]
____________________________________________________________________________________________________
activation_1 (Activation) (None, 32, None, None 0 convolution2d_1[0][0]
____________________________________________________________________________________________________
maxpooling2d_1 (MaxPooling2D) (None, 32, None, None 0 activation_1[0][0]
____________________________________________________________________________________________________
convolution2d_2 (Convolution2D) (None, 32, None, None 9248 maxpooling2d_1[0][0]
____________________________________________________________________________________________________
maxpooling2d_2 (MaxPooling2D) (None, 32, None, None 0 convolution2d_2[0][0]
____________________________________________________________________________________________________
convolution2d_3 (Convolution2D) (None, 16, None, None 528 maxpooling2d_2[0][0]
____________________________________________________________________________________________________
activation_2 (Activation) (None, 16, None, None 0 convolution2d_3[0][0]
____________________________________________________________________________________________________
convolution2d_4 (Convolution2D) (None, 8, None, None) 136 activation_2[0][0]
____________________________________________________________________________________________________
activation_3 (Activation) (None, 8, None, None) 0 convolution2d_4[0][0]
____________________________________________________________________________________________________
convolution2d_5 (Convolution2D) (None, 3, None, None) 27 activation_3[0][0]
____________________________________________________________________________________________________
activation_4 (Activation) (None, 3, None, None) 0 convolution2d_5[0][0]
____________________________________________________________________________________________________
globalaveragepooling2d_1 (Global (None, 3) 0 activation_4[0][0]
____________________________________________________________________________________________________
activation_5 (Activation) (None, 3) 0 globalaveragepooling2d_1[0][0]
====================================================================================================
Total params: 10,259
Trainable params: 10,259
Non-trainable params: 0