python - train - TensorFlow Object Detection API Comportamiento extraño
tensorflow object detection tutorial (3)
¿Cuántas imágenes hay en el conjunto de datos? Cuantos más datos de entrenamiento tengas, mejor será el rendimiento de la API. Intenté entrenarlo en unas 20 imágenes por clase, la precisión era bastante mala. Casi enfrenté todos los problemas que mencionaste anteriormente. Cuando generé más datos, la precisión mejoró considerablemente.
PD: Lo siento, no pude comentar ya que no tengo suficiente reputación.
Estaba jugando con la nueva API de detección de objetos de TensorFlow y decidí entrenarla en otros conjuntos de datos disponibles públicamente.
Me encontré con this conjunto de datos de abarrotes que consiste en imágenes de varias marcas de cajas de cigarrillos en el estante del supermercado junto con un archivo de texto que enumera los cuadros delimitadores de cada caja de cigarrillos en cada imagen. 10 marcas principales se han etiquetado en el conjunto de datos y todas las demás marcas se encuentran en la categoría 11 "miscelánea".
Seguí su tutorial y logré entrenar el modelo en este conjunto de datos. Debido a las limitaciones en la capacidad de procesamiento, utilicé solo un tercio del conjunto de datos y realicé una división de 70:30 para entrenamiento y pruebas de datos. Utilicé el modelo faster_rcnn_resnet101. Todos los parámetros en mi archivo de configuración son los mismos que los parámetros predeterminados proporcionados por TF.
Después de 16491 pasos globales, probé el modelo en algunas imágenes pero no estoy muy contento con los resultados.
Error al detectar los camellos en el estante superior, mientras que detecta el producto en otras imágenes
¿Por qué no se detectan los Marlboros en la fila superior?
Otro problema que tuve es que el modelo nunca detectó ninguna otra etiqueta, excepto la etiqueta 1
No se ha detectado una instancia de recorte del producto a partir de los datos de entrenamiento.
¡Detecta cajas de cigarrillos con un 99% de confianza incluso en imágenes negativas!
¿Puede alguien ayudarme con lo que está mal? ¿Qué puedo hacer para mejorar la precisión? ¿Y por qué detecta que todos los productos pertenecen a la categoría 1, aunque he mencionado que hay 11 clases en total?
Editar Añadido el mapa de mi etiqueta:
item {
id: 1
name: ''1''
}
item {
id: 2
name: ''2''
}
item {
id: 3
name: ''3''
}
item {
id: 4
name: ''4''
}
item {
id: 5
name: ''5''
}
item {
id: 6
name: ''6''
}
item {
id: 7
name: ''7''
}
item {
id: 8
name: ''8''
}
item {
id: 9
name: ''9''
}
item {
id: 10
name: ''10''
}
item {
id: 11
name: ''11''
}
Así que creo que me di cuenta de lo que está pasando. Hice un análisis en el conjunto de datos y descubrí que está sesgado hacia objetos de la categoría 1.
Esta es la distribución de frecuencia de cada categoría de 1 a 11 (en indexación basada en 0)
0 10440
1 304
2 998
3 67
4 412
5 114
6 190
7 311
8 195
9 78
10 75
Supongo que el modelo está alcanzando un mínimo local en el que etiquetar todo como categoría 1 es suficiente.
Sobre el problema de no detectar algunas cajas: Intenté entrenar nuevamente, pero esta vez no diferencié entre marcas. En su lugar, traté de enseñarle al modelo qué es una caja de cigarrillos. Todavía no estaba detectando todas las cajas.
Entonces decidí recortar la imagen de entrada y proporcionarla como entrada. Solo para ver si los resultados mejoran y lo hizo!
Resulta que las dimensiones de la imagen de entrada eran mucho más grandes que las 600 x 1024 aceptadas por el modelo. Entonces, fue reduciendo estas imágenes a 600 x 1024, lo que significaba que las cajas de cigarrillos estaban perdiendo sus detalles :)
Entonces, decidí probar el modelo original que fue entrenado en todas las clases en imágenes recortadas y funciona como un encanto :)
Esta fue la salida del modelo en la imagen original.
Esta es la salida del modelo cuando recorte el cuarto superior izquierdo y lo proporciono como entrada.
Gracias a todos los que ayudaron! Y felicidades al equipo de TensorFlow por un trabajo increíble para la API :) ¡Ahora todos pueden entrenar modelos de detección de objetos!
Parece que el tamaño del conjunto de datos es bastante pequeño. Resnet es una red grande, que requerirá aún más datos para entrenar adecuadamente.
Qué hacer:
- Aumentar el tamaño del conjunto de datos
- Use redes pre-entrenadas y ajuste su base de datos (probablemente ya lo haga)
- Use el aumento de datos (cambio de tamaño, desenfoque, ...; voltear puede no ser apropiado para este conjunto de datos).