python - ¿La mejor manera de guardar un modelo entrenado en PyTorch?
serialization deep-learning (4)
Depende de lo que quieras hacer.
Caso n. ° 1: guarde el modelo para usarlo usted mismo por inferencia
: guarda el modelo, lo restaura y luego cambia el modelo al modo de evaluación.
Esto se hace porque generalmente tiene capas
BatchNorm
y
Dropout
que por defecto están en modo tren en la construcción:
torch.save(model.state_dict(), filepath)
#Later to restore:
model.load_state_dict(torch.load(filepath))
model.eval()
Caso # 2: Guarde el modelo para reanudar el entrenamiento más tarde : si necesita seguir entrenando el modelo que está a punto de guardar, debe guardar más que solo el modelo. También debe guardar el estado del optimizador, las épocas, la puntuación, etc. Lo haría así:
state = {
''epoch'': epoch,
''state_dict'': model.state_dict(),
''optimizer'': optimizer.state_dict(),
...
}
torch.save(state, filepath)
Para reanudar el entrenamiento, haría cosas como:
state = torch.load(filepath)
, y luego, para restaurar el estado de cada objeto individual, algo como esto:
model.load_state_dict(state[''state_dict''])
optimizer.load_state_dict(state[''optimizer''])
Como está reanudando el entrenamiento,
NO
llame a
model.eval()
una vez que restaure los estados al cargar.
Caso # 3: Modelo para ser usado por otra persona sin acceso a su código
: en Tensorflow puede crear un archivo
.pb
que defina tanto la arquitectura como los pesos del modelo.
Esto es muy útil, especialmente cuando se utiliza
Tensorflow serve
.
La forma equivalente de hacer esto en Pytorch sería:
torch.save(model, filepath)
# Then later:
model = torch.load(filepath)
De esta manera todavía no es a prueba de balas y dado que Pytorch todavía está experimentando muchos cambios, no lo recomendaría.
Estaba buscando formas alternativas de guardar un modelo entrenado en PyTorch. Hasta ahora, he encontrado dos alternativas.
- torch.save() para guardar un modelo y torch.load() para cargar un modelo.
- model.state_dict() para guardar un modelo entrenado y model.load_state_dict() para cargar el modelo guardado.
Me he encontrado con esta discussion donde se recomienda el enfoque 2 sobre el enfoque 1.
Mi pregunta es, ¿por qué se prefiere el segundo enfoque? ¿Es solo porque los módulos torch.nn tienen esas dos funciones y se nos recomienda usarlos?
He encontrado esta página en su repositorio de github, solo pegaré el contenido aquí.
Enfoque recomendado para guardar un modelo
Hay dos enfoques principales para serializar y restaurar un modelo.
El primero (recomendado) guarda y carga solo los parámetros del modelo:
torch.save(the_model.state_dict(), PATH)
Entonces despúes:
the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))
El segundo guarda y carga todo el modelo:
torch.save(the_model, PATH)
Entonces despúes:
the_model = torch.load(PATH)
Sin embargo, en este caso, los datos serializados están vinculados a las clases específicas y a la estructura de directorio exacta utilizada, por lo que pueden romperse de varias maneras cuando se usan en otros proyectos o después de algunos refactores serios.
La biblioteca pickle Python implementa protocolos binarios para serializar y deserializar un objeto Python.
Cuando
import torch
(o cuando usa PyTorch)
import pickle
por usted y no necesita llamar a
pickle.dump()
y
pickle.load()
directamente, que son los métodos para guardar y cargar el objeto.
De hecho,
torch.save()
y
torch.load()
envolverán
pickle.dump()
y
pickle.load()
por usted.
Un
state_dict
la otra respuesta mencionada merece solo unas pocas notas más.
¿Qué
state_dict
tenemos dentro de PyTorch?
En realidad hay dos
state_dict
s.
El modelo PyTorch es
torch.nn.Module
tiene
model.parameters()
llamada
model.parameters()
para obtener parámetros que se pueden aprender (w y b).
Estos parámetros que se pueden aprender, una vez establecidos al azar, se actualizarán con el tiempo a medida que aprendamos.
Los parámetros que se pueden aprender son el primer
state_dict
.
El segundo
state_dict
es el dict de estado del optimizador.
Recuerda que el optimizador se utiliza para mejorar nuestros parámetros de aprendizaje.
Pero el optimizador
state_dict
es fijo.
Nada que aprender allí.
Debido a que los objetos
state_dict
son diccionarios de Python, se pueden guardar, actualizar, alterar y restaurar fácilmente, agregando una gran modularidad a los modelos y optimizadores de PyTorch.
Creemos un modelo súper simple para explicar esto:
import torch
import torch.optim as optim
model = torch.nn.Linear(5, 2)
# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
print("Model''s state_dict:")
for param_tensor in model.state_dict():
print(param_tensor, "/t", model.state_dict()[param_tensor].size())
print("Model weight:")
print(model.weight)
print("Model bias:")
print(model.bias)
print("---")
print("Optimizer''s state_dict:")
for var_name in optimizer.state_dict():
print(var_name, "/t", optimizer.state_dict()[var_name])
Este código generará lo siguiente:
Model''s state_dict:
weight torch.Size([2, 5])
bias torch.Size([2])
Model weight:
Parameter containing:
tensor([[ 0.1328, 0.1360, 0.1553, -0.1838, -0.0316],
[ 0.0479, 0.1760, 0.1712, 0.2244, 0.1408]], requires_grad=True)
Model bias:
Parameter containing:
tensor([ 0.4112, -0.0733], requires_grad=True)
---
Optimizer''s state_dict:
state {}
param_groups [{''lr'': 0.001, ''momentum'': 0.9, ''dampening'': 0, ''weight_decay'': 0, ''nesterov'': False, ''params'': [140695321443856, 140695321443928]}]
Tenga en cuenta que este es un modelo mínimo. Puede intentar agregar una pila de secuenciales
model = torch.nn.Sequential(
torch.nn.Linear(D_in, H),
torch.nn.Conv2d(A, B, C)
torch.nn.Linear(H, D_out),
)
Tenga en cuenta que solo las capas con parámetros que se pueden aprender (capas convolucionales, capas lineales, etc.) y los búferes registrados (capas de batchnorm) tienen entradas en el estado
state_dict
.
Las cosas que no se pueden aprender pertenecen al objeto optimizador
state_dict
, que contiene información sobre el estado del optimizador, así como los hiperparámetros utilizados.
El resto de la historia es igual;
en la fase de inferencia (esta es una fase cuando usamos el modelo después del entrenamiento) para predecir;
predecimos en función de los parámetros que aprendimos.
Entonces, para la inferencia, solo necesitamos guardar los parámetros
model.state_dict()
.
torch.save(model.state_dict(), filepath)
Y para usar model.load_state_dict (torch.load (filepath)) model.eval ()
Nota: No olvide la última línea
model.eval()
esto es crucial después de cargar el modelo.
Tampoco intente guardar
torch.save(model.parameters(), filepath)
.
model.parameters()
es solo el objeto generador.
Por otro lado,
torch.save(model, filepath)
guarda el objeto del modelo en sí mismo, pero tenga en cuenta que el modelo no tiene el optimizador
state_dict
.
Verifique la otra excelente respuesta de @Jadiel de Armas para guardar la sentencia de estado del optimizador.
Una convención común de PyTorch es guardar modelos usando una extensión de archivo .pt o .pth.
Guardar / cargar todo el modelo Guardar:
path = "username/directory/lstmmodelgpu.pth"
torch.save(trainer, path)
Carga:
La clase de modelo debe definirse en alguna parte
model = torch.load(PATH)
model.eval()