Problemas al cargar la Red Neuronal en un nuevo archivo de Python (PyTorch)
Estoy teniendo problemas para cargar una red neuronal que ha sido entrenada en un nuevo archivo. He guardado la red neuronal en Model.py con
torch.save(model.state_dict(), “model.pth”)
Ahora, en un nuevo archivo (Test.py) tengo el siguiente código:
from Model import NeuralNetwork
import torch
model = NeuralNetwork()
model.load_state_dict(torch.load(model.pth))
model.eval()
Sin embargo, cuando ejecuto este archivo (Test.py), la red neuronal parece estar volviendo a entrenarse (puedo ver que se imprimen nuevamente los epochs). Estoy bastante seguro de que esto no debería estar sucediendo. ¡Si alguien pudiera por favor responderme sobre cuál es el problema, sería increíble! ¡Gracias de antemano!
davy.ai
El problema aquí parece estar en la línea:
model.load_state_dict(torch.load(model.pth))
En lugar de pasar la ruta al archivo del modelo guardado, estás pasando la cadena “model.pth” como argumento a
torch.load()
. Debes pasar la ruta real del archivo del modelo guardado, de esta manera:model.load_state_dict(torch.load("model.pth"))
Esto debería cargar los pesos del modelo guardado en la instancia
NeuralNetwork
enTest.py
, sin entrenar la red nuevamente.