start fit_generator fit example epochs classes batch python machine-learning generator keras

python - example - ¿Para qué se usa el parámetro “max_q_size” en “model.fit_generator”?



model fit_generator keras 2 (1)

Construí un generador simple que produce una tuple(inputs, targets) con solo elementos individuales en las listas de inputs y targets . Básicamente, está rastreando el conjunto de datos, un elemento de muestra a la vez.

Paso este generador en:

model.fit_generator(my_generator(), nb_epoch=10, samples_per_epoch=1, max_q_size=1 # defaults to 10 )

Lo entiendo:

  • nb_epoch es el número de veces que se ejecutará el lote de entrenamiento
  • samples_per_epoch es el número de muestras entrenadas con por época

Pero, ¿para qué es max_q_size y por qué sería 10 como predeterminado? Pensé que el propósito de usar un generador era agrupar los conjuntos de datos en partes razonables, ¿por qué la cola adicional?


Esto simplemente define el tamaño máximo de la cola de entrenamiento interno que se utiliza para "precache" de sus muestras desde el generador. Se utiliza durante la generación de las colas.

def generator_queue(generator, max_q_size=10, wait_time=0.05, nb_worker=1): ''''''Builds a threading queue out of a data generator. Used in `fit_generator`, `evaluate_generator`, `predict_generator`. '''''' q = queue.Queue() _stop = threading.Event() def data_generator_task(): while not _stop.is_set(): try: if q.qsize() < max_q_size: try: generator_output = next(generator) except ValueError: continue q.put(generator_output) else: time.sleep(wait_time) except Exception: _stop.set() raise generator_threads = [threading.Thread(target=data_generator_task) for _ in range(nb_worker)] for thread in generator_threads: thread.daemon = True thread.start() return q, _stop

En otras palabras, tiene un hilo que llena la cola hasta una capacidad máxima dada directamente desde su generador, mientras que (por ejemplo) la rutina de entrenamiento consume sus elementos (y algunas veces espera su finalización)

while samples_seen < samples_per_epoch: generator_output = None while not _stop.is_set(): if not data_gen_queue.empty(): generator_output = data_gen_queue.get() break else: time.sleep(wait_time)

¿Y por qué por defecto de 10? No hay ninguna razón en particular, como la mayoría de los valores predeterminados: simplemente tiene sentido, pero también podría usar valores diferentes.

Una construcción como esta sugiere que los autores pensaron en los generadores de datos costosos, lo que podría llevar su tiempo de ejecución. Por ejemplo, considere la descarga de datos a través de una red en la llamada del generador; entonces, tiene sentido guardar en caché los siguientes lotes, y descargar los siguientes en paralelo por razones de eficiencia y ser robusto a los errores de red, etc.