validation tensorflow

validation - Reemplace los monitores de validación con tf.train.SessionRunHook cuando utilice estimadores



tensorflow (3)

Estoy ejecutando un DNNClassifier, para el cual estoy monitoreando la precisión mientras entreno. Monitors.ValidationMonitor de contrib / learn ha funcionado muy bien, en mi implementación lo defino:

validation_monitor = skflow.monitors.ValidationMonitor(input_fn=lambda: input_fn(A_test, Cl2_test), eval_steps=1, every_n_steps=50)

y luego usar la llamada de:

clf.fit(input_fn=lambda: input_fn(A, Cl2), steps=1000, monitors=[validation_monitor])

dónde:

clf = tensorflow.contrib.learn.DNNClassifier(...

Esto funciona bien. Dicho esto, los monitores de validación parecen estar en desuso y una funcionalidad similar debe reemplazarse con tf.train.SessionRunHook .

Soy un novato en TensorFlow, y no me parece trivial cómo se vería esa implementación de reemplazo. Cualquier sugerencia es muy apreciada. De nuevo, necesito validar la capacitación después de un número específico de pasos. Muchas gracias de antemano.


Como desea validar el entrenamiento después de cada n_steps, tf utilizará el último punto de control guardado. Puede usar una clase personalizada de CheckpointSaverListener para agregar el paso de evaluación después de que el punto de control se guarde utilizando CheckpointSaverHook . Pasar el objeto clasificador modelo y la función de entrada de evaluación a la clase.

Referencia https://www.tensorflow.org/api_docs/python/tf/train/CheckpointSaverListener

class ExampleCheckpointSaverListener(CheckpointSaverListener): def __init(self): self.classifier = classifier self.eval_input_fn = eval_input_fn def begin(self): # You can add ops to the graph here. print(''Starting the session.'') self.your_tensor = ... def before_save(self, session, global_step_value): print(''About to write a checkpoint'') eval_op = self.classifier.evaluate(input_fn=self.eval_input_fn) print(eval_op) def after_save(self, session, global_step_value): print(''Done writing checkpoint.'') def end(self, session, global_step_value): print(''Done with the session.'') ... listener = ExampleCheckpointSaverListener(Myclassifier, eval_input_fn ) saver_hook = tf.train.CheckpointSaverHook( checkpoint_dir, listeners=[listener]) with tf.train.MonitoredTrainingSession(chief_only_hooks=[saver_hook]):


Hay una utilidad no documentada llamada monitors.replace_monitors_with_hooks() que convierte los monitores en ganchos. El método acepta (i) una lista que puede contener monitores y ganchos y (ii) el Estimador para el cual se usarán los ganchos, y luego devuelve una lista de ganchos envolviendo un SessionRunHook alrededor de cada Monitor.

from tensorflow.contrib.learn.python.learn import monitors as monitor_lib clf = tf.estimator.Estimator(...) list_of_monitors_and_hooks = [tf.contrib.learn.monitors.ValidationMonitor(...)] hooks = monitor_lib.replace_monitors_with_hooks(list_of_monitors_and_hooks, clf)

Esta no es realmente una solución real al problema de reemplazar por completo el ValidationMonitor; en lugar de eso, simplemente lo estamos terminando con una función no obsoleta. Sin embargo, puedo decir que esto me ha funcionado hasta el momento en que mantuvo toda la funcionalidad que necesito del ValidationMonitor (es decir, evaluar cada n pasos, dejar de usar una métrica, etc.)

Una cosa más: para usar este gancho, deberá actualizar desde un tf.contrib.learn.Estimator (que solo acepta monitores) al más completo y oficial tf.estimator.Estimator (que solo acepta ganchos). Por lo tanto, debe crear una instancia de su clasificador como tf.estimator.DNNClassifier , y entrenar utilizando su método train() lugar (que es solo un cambio de nombre de fit() ):

clf = tf.estimator.Estimator(...) ... clf.train( input_fn=... ... hooks=hooks)


tf.train.SessionRunHook una manera de extender tf.train.SessionRunHook como se sugiere.

import tensorflow as tf class ValidationHook(tf.train.SessionRunHook): def __init__(self, model_fn, params, input_fn, checkpoint_dir, every_n_secs=None, every_n_steps=None): self._iter_count = 0 self._estimator = tf.estimator.Estimator( model_fn=model_fn, params=params, model_dir=checkpoint_dir ) self._input_fn = input_fn self._timer = tf.train.SecondOrStepTimer(every_n_secs, every_n_steps) self._should_trigger = False def begin(self): self._timer.reset() self._iter_count = 0 def before_run(self, run_context): self._should_trigger = self._timer.should_trigger_for_step(self._iter_count) def after_run(self, run_context, run_values): if self._should_trigger: self._estimator.evaluate( self._input_fn ) self._timer.update_last_triggered_step(self._iter_count) self._iter_count += 1

y lo usé como un training_hook en Estimator.train :

estimator.train(input_fn=_input_fn(...), steps=num_epochs * num_steps_per_epoch, hooks=[ValidationHook(...)])

No tiene nada de extravagante que un ValidationMonitor tenga como parar temprano y todo eso, pero esto debería ser un comienzo.