what protocol protobuf google buffers protocol-buffers tensorflow

protocol buffers - protocol - ¿Hay algún ejemplo sobre cómo generar archivos protobuf que contengan gráficos entrenados de TensorFlow?



google:: protobuf (6)

Acabo de encontrar esta publicación y fue muy útil gracias! También voy con el método de @ Mostafa, aunque mi código C ++ es un poco diferente:

std::vector<string> names; int node_count = graph.node_size(); cout << node_count << " nodes in graph" << endl; // iterate all nodes for(int i=0; i<node_count; i++) { auto n = graph.node(i); cout << i << ":" << n.name() << endl; // if name contains "var_hack", add to vector if(n.name().find("var_hack") != std::string::npos) { names.push_back(n.name()); cout << "......bang" << endl; } } session.Run({}, names, {}, &outputs);

Nota: uso "var_hack" como mi nombre de variable en python

Estoy viendo el ejemplo de Google sobre cómo implementar y usar un gráfico (modelo) pretensado de Tensorflow en Android. Este ejemplo utiliza un archivo .pb en:

https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip

que es un enlace a un archivo que se descarga automáticamente .

El ejemplo muestra cómo cargar el archivo .pb en una sesión de Tensorflow y usarlo para realizar la clasificación, pero no parece mencionar cómo generar dicho archivo .pb , después de entrenar un gráfico (por ejemplo, en Python).

¿Hay algún ejemplo sobre cómo hacer eso?


Alternativamente a mi respuesta anterior usando freeze_graph() , que solo es bueno si lo llamas como un script, hay una función muy agradable que hará todo el trabajo pesado por ti y es adecuado para ser llamado desde tu código de entrenamiento de modelo normal.

convert_variables_to_constants() hace dos cosas:

  • Congela los pesos reemplazando variables con constantes.
  • Elimina nodos que no están relacionados con la predicción de avance

Suponiendo que sess es su tf.Session() y "output" es el nombre de su nodo de predicción, el siguiente código serializará su gráfico mínimo tanto en protobuf textual como binario.

from tensorflow.python.framework.graph_util import convert_variables_to_constants minimal_graph = convert_variables_to_constants(sess, sess.graph_def, ["output"]) tf.train.write_graph(minimal_graph, ''.'', ''minimal_graph.proto'', as_text=False) tf.train.write_graph(minimal_graph, ''.'', ''minimal_graph.txt'', as_text=True)


Aquí hay otra versión de la respuesta de @ Mostafa. Una forma algo más limpia de ejecutar las operaciones tf.assign es almacenarlas en un grupo tf.group . Aquí está mi código de Python:

ops = [] for v in tf.trainable_variables(): vc = tf.constant(v.eval()) ops.append(tf.assign(v, vc)); tf.group(*ops, name="assign_trained_variables")

Y en C ++:

std::vector<tensorflow::Tensor> tmp; status = session.Run({}, {}, { "assign_trained_variables" }, &tmp); if (!status.ok()) { // Handle error }

De esta manera, solo tiene una operación con nombre para ejecutar en el lado de C ++, por lo que no tiene que perder el tiempo iterando sobre los nodos.


Encontré una función freeze_graph.py en la base de código de Tensorflow que podría ser útil al hacer esto. Por lo que entiendo, intercambia variables con constantes antes de serializar GraphDef y, por lo tanto, cuando carga este gráfico desde C ++, ya no tiene variables que deben configurarse, y puede usarlo directamente para las predicciones.

También hay una test y alguna descripción en la Guide .

Esta parece ser la opción más limpia aquí.


No pude descubrir cómo implementar el método descrito por mrry. Pero aquí cómo lo resolví. No estoy seguro de si esa es la mejor manera de resolver el problema, pero al menos lo resuelve.

Como write_graph también puede almacenar los valores de las constantes, agregué el siguiente código al python justo antes de escribir el gráfico con la función write_graph:

for v in tf.trainable_variables(): vc = tf.constant(v.eval()) tf.assign(v, vc, name="assign_variables")

Esto crea constantes que almacenan los valores de las variables después de ser entrenados y luego crean tensores " asignar_variables " para asignarlos a las variables. Ahora, cuando llame a write_graph, almacenará los valores de las variables en el archivo en forma de constantes.

La única parte restante es llamar a estos tensores " asignar_variables " en el código c para asegurarse de que sus variables estén asignadas con los valores constantes que están almacenados en el archivo. Aquí hay una forma de hacerlo:

Status status = NewSession(SessionOptions(), &session); std::vector<tensorflow::Tensor> outputs; char name[100]; for(int i = 0;status.ok(); i++) { if (i==0) sprintf(name, "assign_variables"); else sprintf(name, "assign_variables_%d", i); status = session->Run({}, {name}, {}, &outputs); }


EDITAR: El script freeze_graph.py , que es parte del repositorio TensorFlow, ahora sirve como una herramienta que genera un búfer de protocolo que representa un modelo entrenado "congelado", a partir de un TensorFlow GraphDef existente y un punto de control guardado. Utiliza los mismos pasos que se describen a continuación, pero es mucho más fácil de usar.

Actualmente el proceso no está muy bien documentado (y sujeto a refinamiento), pero los pasos aproximados son los siguientes:

  1. Construye y entrena tu modelo como un tf.Graph llamado g_1 .
  2. Obtenga los valores finales de cada una de las variables y almacénelos como matrices numpy (usando Session.run() ).
  3. En un nuevo tf.Graph llamado g_2 , cree los tensores tf.constant() para cada una de las variables, utilizando el valor de la matriz numpy correspondiente obtenida en el paso 2.
  4. Use tf.import_graph_def() para copiar nodos de g_1 en g_2 , y use el argumento input_map para reemplazar cada variable en g_1 con los tensores tf.constant() correspondientes creados en el paso 3. También puede usar input_map para especificar un nuevo tensor de entrada (por ejemplo, reemplazar una tubería de entrada con un tf.placeholder() ). Use el argumento return_elements para especificar el nombre del tensor de salida pronosticado.

  5. Llame a g_2.as_graph_def() para obtener una representación del búfer de protocolo del gráfico.

( NOTA: El gráfico generado tendrá nodos adicionales en el gráfico para entrenamiento. Aunque no es parte de la API pública, es posible que desee utilizar la función interna graph_util.extract_sub_graph() para quitar estos nodos del gráfico).