python - example - ¿Para qué se usa el parámetro “max_q_size” en “model.fit_generator”?
model fit_generator keras 2 (1)
Construí un generador simple que produce una tuple(inputs, targets)
con solo elementos individuales en las listas de inputs
y targets
. Básicamente, está rastreando el conjunto de datos, un elemento de muestra a la vez.
Paso este generador en:
model.fit_generator(my_generator(),
nb_epoch=10,
samples_per_epoch=1,
max_q_size=1 # defaults to 10
)
Lo entiendo:
-
nb_epoch
es el número de veces que se ejecutará el lote de entrenamiento -
samples_per_epoch
es el número de muestras entrenadas con por época
Pero, ¿para qué es max_q_size
y por qué sería 10 como predeterminado? Pensé que el propósito de usar un generador era agrupar los conjuntos de datos en partes razonables, ¿por qué la cola adicional?
Esto simplemente define el tamaño máximo de la cola de entrenamiento interno que se utiliza para "precache" de sus muestras desde el generador. Se utiliza durante la generación de las colas.
def generator_queue(generator, max_q_size=10,
wait_time=0.05, nb_worker=1):
''''''Builds a threading queue out of a data generator.
Used in `fit_generator`, `evaluate_generator`, `predict_generator`.
''''''
q = queue.Queue()
_stop = threading.Event()
def data_generator_task():
while not _stop.is_set():
try:
if q.qsize() < max_q_size:
try:
generator_output = next(generator)
except ValueError:
continue
q.put(generator_output)
else:
time.sleep(wait_time)
except Exception:
_stop.set()
raise
generator_threads = [threading.Thread(target=data_generator_task)
for _ in range(nb_worker)]
for thread in generator_threads:
thread.daemon = True
thread.start()
return q, _stop
En otras palabras, tiene un hilo que llena la cola hasta una capacidad máxima dada directamente desde su generador, mientras que (por ejemplo) la rutina de entrenamiento consume sus elementos (y algunas veces espera su finalización)
while samples_seen < samples_per_epoch:
generator_output = None
while not _stop.is_set():
if not data_gen_queue.empty():
generator_output = data_gen_queue.get()
break
else:
time.sleep(wait_time)
¿Y por qué por defecto de 10? No hay ninguna razón en particular, como la mayoría de los valores predeterminados: simplemente tiene sentido, pero también podría usar valores diferentes.
Una construcción como esta sugiere que los autores pensaron en los generadores de datos costosos, lo que podría llevar su tiempo de ejecución. Por ejemplo, considere la descarga de datos a través de una red en la llamada del generador; entonces, tiene sentido guardar en caché los siguientes lotes, y descargar los siguientes en paralelo por razones de eficiencia y ser robusto a los errores de red, etc.