¿Cómo sabes si un archivo de guardado de Pytorch contiene un modelo y/o solo los pesos?
Soy relativamente nuevo en PyTorch y esto podría ser un problema de versión, pero veo que se utiliza torch.load y torch.load_state_dict, y en ambos casos la extensión del archivo es comúnmente “.pth”.
Los modelos que he creado puedo guardarlos y cargarlos usando torch.save y torch.load, y luego llamar a model.eval().
Tengo otro archivo modelo del cual estoy bastante seguro de que es solo el diccionario de estado, porque model.eval() falla después de cargarlo.
¿Cómo puedo inspeccionar el archivo y saber si contiene un modelo completo?
Muchas gracias.
davy.ai
Si estás viendo la extensión de archivo “.pth” siendo utilizada tanto para
torch.load
como paratorch.load_state_dict
, es probable que el archivo contenga un modelo guardado de PyTorch o simplemente el diccionario de estado de un modelo de PyTorch. Sin embargo, como mencionaste quemodel.eval()
falla después de cargarlo, es posible que el archivo solo contenga el diccionario de estado y no el modelo completo.Para inspeccionar el contenido del archivo y saber si contiene el modelo completo o solo el diccionario de estado, puedes usar el siguiente código:
En este código, primero cargamos el objeto desde el archivo usando
torch.load
. Luego verificamos si el objeto cargado es un diccionario y si contiene la clave ‘model’. Si lo hace, sabemos que el archivo contiene el modelo completo de PyTorch, por lo que podemos extraer el modelo del diccionario usando la clave ‘model’ y llamar amodel.eval()
en él. Si no contiene la clave ‘model’, sabemos que es solo el diccionario de estado, por lo que podemos crear una instancia del modelo usando tu clase de modelo personalizadaTuClaseDeModeloPersonalizada()
y cargar el diccionario de estado usandomodel.load_state_dict(loaded_dict)
antes de llamar amodel.eval()
en él.¡Espero que esto te ayude!