validation - Reemplace los monitores de validación con tf.train.SessionRunHook cuando utilice estimadores
tensorflow (3)
Estoy ejecutando un DNNClassifier, para el cual estoy monitoreando la precisión mientras entreno. Monitors.ValidationMonitor de contrib / learn ha funcionado muy bien, en mi implementación lo defino:
validation_monitor = skflow.monitors.ValidationMonitor(input_fn=lambda: input_fn(A_test, Cl2_test), eval_steps=1, every_n_steps=50)
y luego usar la llamada de:
clf.fit(input_fn=lambda: input_fn(A, Cl2),
steps=1000, monitors=[validation_monitor])
dónde:
clf = tensorflow.contrib.learn.DNNClassifier(...
Esto funciona bien. Dicho esto, los monitores de validación parecen estar en desuso y una funcionalidad similar debe reemplazarse con tf.train.SessionRunHook
.
Soy un novato en TensorFlow, y no me parece trivial cómo se vería esa implementación de reemplazo. Cualquier sugerencia es muy apreciada. De nuevo, necesito validar la capacitación después de un número específico de pasos. Muchas gracias de antemano.
Como desea validar el entrenamiento después de cada n_steps, tf utilizará el último punto de control guardado. Puede usar una clase personalizada de CheckpointSaverListener
para agregar el paso de evaluación después de que el punto de control se guarde utilizando CheckpointSaverHook
. Pasar el objeto clasificador modelo y la función de entrada de evaluación a la clase.
Referencia https://www.tensorflow.org/api_docs/python/tf/train/CheckpointSaverListener
class ExampleCheckpointSaverListener(CheckpointSaverListener):
def __init(self):
self.classifier = classifier
self.eval_input_fn = eval_input_fn
def begin(self):
# You can add ops to the graph here.
print(''Starting the session.'')
self.your_tensor = ...
def before_save(self, session, global_step_value):
print(''About to write a checkpoint'')
eval_op = self.classifier.evaluate(input_fn=self.eval_input_fn)
print(eval_op)
def after_save(self, session, global_step_value):
print(''Done writing checkpoint.'')
def end(self, session, global_step_value):
print(''Done with the session.'')
...
listener = ExampleCheckpointSaverListener(Myclassifier, eval_input_fn )
saver_hook = tf.train.CheckpointSaverHook(
checkpoint_dir, listeners=[listener])
with tf.train.MonitoredTrainingSession(chief_only_hooks=[saver_hook]):
Hay una utilidad no documentada llamada monitors.replace_monitors_with_hooks()
que convierte los monitores en ganchos. El método acepta (i) una lista que puede contener monitores y ganchos y (ii) el Estimador para el cual se usarán los ganchos, y luego devuelve una lista de ganchos envolviendo un SessionRunHook alrededor de cada Monitor.
from tensorflow.contrib.learn.python.learn import monitors as monitor_lib
clf = tf.estimator.Estimator(...)
list_of_monitors_and_hooks = [tf.contrib.learn.monitors.ValidationMonitor(...)]
hooks = monitor_lib.replace_monitors_with_hooks(list_of_monitors_and_hooks, clf)
Esta no es realmente una solución real al problema de reemplazar por completo el ValidationMonitor; en lugar de eso, simplemente lo estamos terminando con una función no obsoleta. Sin embargo, puedo decir que esto me ha funcionado hasta el momento en que mantuvo toda la funcionalidad que necesito del ValidationMonitor (es decir, evaluar cada n pasos, dejar de usar una métrica, etc.)
Una cosa más: para usar este gancho, deberá actualizar desde un tf.contrib.learn.Estimator
(que solo acepta monitores) al más completo y oficial tf.estimator.Estimator
(que solo acepta ganchos). Por lo tanto, debe crear una instancia de su clasificador como tf.estimator.DNNClassifier
, y entrenar utilizando su método train()
lugar (que es solo un cambio de nombre de fit()
):
clf = tf.estimator.Estimator(...)
...
clf.train(
input_fn=...
...
hooks=hooks)
tf.train.SessionRunHook
una manera de extender tf.train.SessionRunHook
como se sugiere.
import tensorflow as tf
class ValidationHook(tf.train.SessionRunHook):
def __init__(self, model_fn, params, input_fn, checkpoint_dir,
every_n_secs=None, every_n_steps=None):
self._iter_count = 0
self._estimator = tf.estimator.Estimator(
model_fn=model_fn,
params=params,
model_dir=checkpoint_dir
)
self._input_fn = input_fn
self._timer = tf.train.SecondOrStepTimer(every_n_secs, every_n_steps)
self._should_trigger = False
def begin(self):
self._timer.reset()
self._iter_count = 0
def before_run(self, run_context):
self._should_trigger = self._timer.should_trigger_for_step(self._iter_count)
def after_run(self, run_context, run_values):
if self._should_trigger:
self._estimator.evaluate(
self._input_fn
)
self._timer.update_last_triggered_step(self._iter_count)
self._iter_count += 1
y lo usé como un training_hook
en Estimator.train
:
estimator.train(input_fn=_input_fn(...),
steps=num_epochs * num_steps_per_epoch,
hooks=[ValidationHook(...)])
No tiene nada de extravagante que un ValidationMonitor
tenga como parar temprano y todo eso, pero esto debería ser un comienzo.