es.davy.ai

Preguntas y respuestas de programación confiables

¿Tienes una pregunta?

Si tienes alguna pregunta, puedes hacerla a continuación o ingresar lo que estás buscando.

PyTorch – (Modelo de implementación) convertir de .pth a onnx

Quiero convertir mi modelo de pyTorch (.pth) a ONNX.
Este es mi código.

import torch
import torchvision
dummy_input = torch.randn(1,3,224,224,device = "cuda")
model = torch.load("/home/selfsup/moco/r50_v2/epoch_200.pth")
torch.onnx.export(model,dummy_input,"LENet.onnx",verbose = True)

Sin embargo, obtengo un error que dice lo siguiente.

Traceback (most recent call last):
  File "tengine_convert.py", line 5, in <module>
    torch.onnx.export(model,dummy_input,"LENet.onnx",verbose = True)
  File "/home/anaconda3/lib/python3.6/site-packages/torch/onnx/__init__.py", line 320, in export
    custom_opsets, enable_onnx_checker, use_external_data_format)
  File "/home/anaconda3/lib/python3.6/site-packages/torch/onnx/utils.py", line 111, in export
    custom_opsets=custom_opsets, use_external_data_format=use_external_data_format)
  File "/home/anaconda3/lib/python3.6/site-packages/torch/onnx/utils.py", line 709, in _export
    with select_model_mode_for_export(model, training):
  File "/home/anaconda3/lib/python3.6/contextlib.py", line 81, in __enter__
    return next(self.gen)
  File "/home/anaconda3/lib/python3.6/site-packages/torch/onnx/utils.py", line 42, in select_model_mode_for_export
    is_originally_training = model.training
AttributeError: 'dict' object has no attribute 'training'

¿Por qué sucede esto?

Tags:  , , ,

Answer

  1. Avatar for davy.ai

    El error ocurre porque el modelo cargado utilizando torch.load() devuelve un diccionario en lugar de una instancia de modelo de PyTorch. Por lo tanto, al intentar exportarlo al formato ONNX, la función select_model_mode_for_export() espera una instancia de modelo de PyTorch con un atributo training, el cual no está presente en el diccionario.

    Para solucionar este problema, debes cargar una instancia de modelo de PyTorch en lugar de un diccionario. Puedes hacer esto utilizando torch.load() con el parámetro map_location de la siguiente manera:

    model = torch.load('/home/selfsup/moco/r50_v2/epoch_200.pth', map_location=torch.device('cpu'))
    

    Esto carga el modelo en la CPU, incluso si originalmente se guardó en la GPU. Luego, puedes exportar el modelo al formato ONNX utilizando la función torch.onnx.export().

Comments are closed.