train - tensorflow programming guide
El script freeze_graph de Tensorflow falla en el modelo definido con Keras (2)
Estoy intentando exportar un modelo creado y entrenado con Keras a un protobuffer que puedo cargar en un script C ++ (como en este ejemplo). He generado un archivo .pb que contiene la definición del modelo y un archivo .ckpt que contiene los datos del punto de control. Sin embargo, cuando intento fusionarlos en un único archivo con el script freeze_graph, aparece el siguiente error:
ValueError: Fetch argument ''save/restore_all'' of ''save/restore_all'' cannot be interpreted as a Tensor. ("The name ''save/restore_all'' refers to an Operation not in the graph.")
Estoy guardando el modelo así:
with tf.Session() as sess:
model = nndetector.architecture.models.vgg19((3, 50, 50))
model.load_weights(''/srv/nn/weights/scratch-vgg19.h5'')
init_op = tf.initialize_all_variables()
sess.run(init_op)
graph_def = sess.graph.as_graph_def()
tf.train.write_graph(graph_def=graph_def, logdir=''.'', name=''model.pb'', as_text=False)
saver = tf.train.Saver()
saver.save(sess, ''model.ckpt'')
nndetector.architecture.models.vgg19 ((3, 50, 50)) es simplemente un modelo similar a vgg19 definido en Keras.
Llamaré al script freeze_graph así:
bazel-bin/tensorflow/python/tools/freeze_graph --input_graph=[path-to-model.pb] --input_checkpoint=[path-to-model.ckpt] --output_graph=[output-path] --output_node_names=sigmoid --input_binary=True
Si ejecuto el script freeze_graph_test
todo funciona bien.
¿Alguien sabe lo que estoy haciendo mal?
Gracias.
Atentamente
Felipe
EDITAR
He intentado imprimir tf.train.Saver().as_saver_def().restore_op_name
que devuelve save/restore_all
.
Además, probé un ejemplo simple de tensorflow puro y aún recibo el mismo error:
a = tf.Variable(tf.constant(1), name=''a'')
b = tf.Variable(tf.constant(2), name=''b'')
add = tf.add(a, b, ''sum'')
with tf.Session() as sess:
sess.run(tf.initialize_all_variables())
tf.train.write_graph(graph_def=sess.graph.as_graph_def(), logdir=''.'', name=''simple_as_binary.pb'', as_text=False)
tf.train.Saver().save(sess, ''simple.ckpt'')
Y tampoco puedo restaurar el gráfico en python. El uso del siguiente código arroja ValueError: No variables to save
si lo ejecuto por separado de guardar el gráfico (es decir, si ValueError: No variables to save
y restauro el modelo en el mismo script, todo funciona bien).
with gfile.FastGFile(''simple_as_binary.pb'') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
with tf.Session() as sess:
tf.import_graph_def(graph_def)
saver = tf.train.Saver()
saver.restore(sess, ''simple.ckpt'')
No estoy seguro si los dos problemas están relacionados, o si simplemente no estoy restaurando el modelo correctamente en Python.
El problema es el orden de estas dos líneas en su programa original:
tf.train.write_graph(graph_def=sess.graph.as_graph_def(), logdir=''.'', name=''simple_as_binary.pb'', as_text=False)
tf.train.Saver().save(sess, ''simple.ckpt'')
Llamar a tf.train.Saver()
agrega un conjunto de nodos al gráfico, incluido uno llamado "save/restore_all"
. Sin embargo, este programa lo llama después de escribir el gráfico, por lo que el archivo que pasa a freeze_graph.py
no contiene esos nodos, que son necesarios para hacer la reescritura.
Invertir las dos líneas debe hacer que el script funcione según lo previsto:
tf.train.Saver().save(sess, ''simple.ckpt'')
tf.train.write_graph(graph_def=sess.graph.as_graph_def(), logdir=''.'', name=''simple_as_binary.pb'', as_text=False)
Entonces, lo tengo funcionando. Más o menos
Al usar tensorflow.python.client.graph_util.convert_variables_to_constants
directamente en lugar de guardar primero GraphDef
y un punto de control en el disco y luego usar la herramienta / script freeze_graph
, he podido guardar un GraphDef
contiene tanto la definición del gráfico como las variables convertidas en constantes.
EDITAR
mrry actualizó su respuesta, lo que resolvió mi problema de que freeze_graph no funcionaba, pero también dejaré esta respuesta, en caso de que alguien más pueda encontrarla útil.