PyTorch – (Modelo de implementación) convertir de .pth a onnx
Quiero convertir mi modelo de pyTorch (.pth) a ONNX.
Este es mi código.
import torch
import torchvision
dummy_input = torch.randn(1,3,224,224,device = "cuda")
model = torch.load("/home/selfsup/moco/r50_v2/epoch_200.pth")
torch.onnx.export(model,dummy_input,"LENet.onnx",verbose = True)
Sin embargo, obtengo un error que dice lo siguiente.
Traceback (most recent call last):
File "tengine_convert.py", line 5, in <module>
torch.onnx.export(model,dummy_input,"LENet.onnx",verbose = True)
File "/home/anaconda3/lib/python3.6/site-packages/torch/onnx/__init__.py", line 320, in export
custom_opsets, enable_onnx_checker, use_external_data_format)
File "/home/anaconda3/lib/python3.6/site-packages/torch/onnx/utils.py", line 111, in export
custom_opsets=custom_opsets, use_external_data_format=use_external_data_format)
File "/home/anaconda3/lib/python3.6/site-packages/torch/onnx/utils.py", line 709, in _export
with select_model_mode_for_export(model, training):
File "/home/anaconda3/lib/python3.6/contextlib.py", line 81, in __enter__
return next(self.gen)
File "/home/anaconda3/lib/python3.6/site-packages/torch/onnx/utils.py", line 42, in select_model_mode_for_export
is_originally_training = model.training
AttributeError: 'dict' object has no attribute 'training'
¿Por qué sucede esto?
davy.ai
El error ocurre porque el modelo cargado utilizando
torch.load()
devuelve un diccionario en lugar de una instancia de modelo de PyTorch. Por lo tanto, al intentar exportarlo al formato ONNX, la funciónselect_model_mode_for_export()
espera una instancia de modelo de PyTorch con un atributotraining
, el cual no está presente en el diccionario.Para solucionar este problema, debes cargar una instancia de modelo de PyTorch en lugar de un diccionario. Puedes hacer esto utilizando
torch.load()
con el parámetromap_location
de la siguiente manera:Esto carga el modelo en la CPU, incluso si originalmente se guardó en la GPU. Luego, puedes exportar el modelo al formato ONNX utilizando la función
torch.onnx.export()
.