protocol buffers - protocol - ¿Hay algún ejemplo sobre cómo generar archivos protobuf que contengan gráficos entrenados de TensorFlow?
google:: protobuf (6)
Acabo de encontrar esta publicación y fue muy útil gracias! También voy con el método de @ Mostafa, aunque mi código C ++ es un poco diferente:
std::vector<string> names;
int node_count = graph.node_size();
cout << node_count << " nodes in graph" << endl;
// iterate all nodes
for(int i=0; i<node_count; i++) {
auto n = graph.node(i);
cout << i << ":" << n.name() << endl;
// if name contains "var_hack", add to vector
if(n.name().find("var_hack") != std::string::npos) {
names.push_back(n.name());
cout << "......bang" << endl;
}
}
session.Run({}, names, {}, &outputs);
Nota: uso "var_hack" como mi nombre de variable en python
Estoy viendo
el ejemplo de Google
sobre cómo implementar y usar un gráfico (modelo) pretensado de Tensorflow en Android.
Este ejemplo utiliza un archivo
.pb
en:
https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip
que es un enlace a un archivo que se descarga automáticamente .
El ejemplo muestra cómo cargar el archivo
.pb
en una sesión de Tensorflow y usarlo para realizar la clasificación, pero no parece mencionar cómo generar dicho archivo
.pb
, después de entrenar un gráfico (por ejemplo, en Python).
¿Hay algún ejemplo sobre cómo hacer eso?
Alternativamente a mi respuesta anterior usando
freeze_graph()
, que solo es bueno si lo llamas como un script, hay una función muy agradable que hará todo el trabajo pesado por ti y es adecuado para ser llamado desde tu código de entrenamiento de modelo normal.
convert_variables_to_constants()
hace dos cosas:
- Congela los pesos reemplazando variables con constantes.
- Elimina nodos que no están relacionados con la predicción de avance
Suponiendo que
sess
es su
tf.Session()
y
"output"
es el nombre de su nodo de predicción, el siguiente código serializará su gráfico mínimo tanto en protobuf textual como binario.
from tensorflow.python.framework.graph_util import convert_variables_to_constants
minimal_graph = convert_variables_to_constants(sess, sess.graph_def, ["output"])
tf.train.write_graph(minimal_graph, ''.'', ''minimal_graph.proto'', as_text=False)
tf.train.write_graph(minimal_graph, ''.'', ''minimal_graph.txt'', as_text=True)
Aquí hay otra versión de la respuesta de @ Mostafa.
Una forma algo más limpia de ejecutar las operaciones
tf.assign
es almacenarlas en un grupo
tf.group
.
Aquí está mi código de Python:
ops = []
for v in tf.trainable_variables():
vc = tf.constant(v.eval())
ops.append(tf.assign(v, vc));
tf.group(*ops, name="assign_trained_variables")
Y en C ++:
std::vector<tensorflow::Tensor> tmp;
status = session.Run({}, {}, { "assign_trained_variables" }, &tmp);
if (!status.ok()) {
// Handle error
}
De esta manera, solo tiene una operación con nombre para ejecutar en el lado de C ++, por lo que no tiene que perder el tiempo iterando sobre los nodos.
Encontré una función freeze_graph.py en la base de código de Tensorflow que podría ser útil al hacer esto. Por lo que entiendo, intercambia variables con constantes antes de serializar GraphDef y, por lo tanto, cuando carga este gráfico desde C ++, ya no tiene variables que deben configurarse, y puede usarlo directamente para las predicciones.
También hay una test y alguna descripción en la Guide .
Esta parece ser la opción más limpia aquí.
No pude descubrir cómo implementar el método descrito por mrry. Pero aquí cómo lo resolví. No estoy seguro de si esa es la mejor manera de resolver el problema, pero al menos lo resuelve.
Como write_graph también puede almacenar los valores de las constantes, agregué el siguiente código al python justo antes de escribir el gráfico con la función write_graph:
for v in tf.trainable_variables():
vc = tf.constant(v.eval())
tf.assign(v, vc, name="assign_variables")
Esto crea constantes que almacenan los valores de las variables después de ser entrenados y luego crean tensores " asignar_variables " para asignarlos a las variables. Ahora, cuando llame a write_graph, almacenará los valores de las variables en el archivo en forma de constantes.
La única parte restante es llamar a estos tensores " asignar_variables " en el código c para asegurarse de que sus variables estén asignadas con los valores constantes que están almacenados en el archivo. Aquí hay una forma de hacerlo:
Status status = NewSession(SessionOptions(), &session);
std::vector<tensorflow::Tensor> outputs;
char name[100];
for(int i = 0;status.ok(); i++) {
if (i==0)
sprintf(name, "assign_variables");
else
sprintf(name, "assign_variables_%d", i);
status = session->Run({}, {name}, {}, &outputs);
}
EDITAR:
El script
freeze_graph.py
, que es parte del repositorio TensorFlow, ahora sirve como una herramienta que genera un búfer de protocolo que representa un modelo entrenado "congelado", a partir de un TensorFlow
GraphDef
existente y un punto de control guardado.
Utiliza los mismos pasos que se describen a continuación, pero es mucho más fácil de usar.
Actualmente el proceso no está muy bien documentado (y sujeto a refinamiento), pero los pasos aproximados son los siguientes:
-
Construye y entrena tu modelo como un
tf.Graph
llamadog_1
. -
Obtenga los valores finales de cada una de las variables y almacénelos como matrices numpy (usando
Session.run()
). -
En un nuevo
tf.Graph
llamadog_2
, cree los tensorestf.constant()
para cada una de las variables, utilizando el valor de la matriz numpy correspondiente obtenida en el paso 2. -
Use
tf.import_graph_def()
para copiar nodos deg_1
eng_2
, y use el argumentoinput_map
para reemplazar cada variable eng_1
con los tensorestf.constant()
correspondientes creados en el paso 3. También puede usarinput_map
para especificar un nuevo tensor de entrada (por ejemplo, reemplazar una tubería de entrada con untf.placeholder()
). Use el argumentoreturn_elements
para especificar el nombre del tensor de salida pronosticado. -
Llame a
g_2.as_graph_def()
para obtener una representación del búfer de protocolo del gráfico.
(
NOTA:
El gráfico generado tendrá nodos adicionales en el gráfico para entrenamiento. Aunque no es parte de la API pública, es posible que desee utilizar la función interna
graph_util.extract_sub_graph()
para quitar estos nodos del gráfico).