es.davy.ai

Preguntas y respuestas de programación confiables

¿Tienes una pregunta?

Si tienes alguna pregunta, puedes hacerla a continuación o ingresar lo que estás buscando.

Pytorch: Cargue un modelo transformer directamente desde el punto de control de antorcha sin cargar pesos preinicializados.

Tengo una restricción de memoria al cargar un modelo desde un punto de control de antorcha para inferencia:

Lo que tengo es lo siguiente:

Un modelo de antorcha (xlm-roberta)
y un punto de control (xlm-roberta-checkpoint.pth)

Lo que normalmente sucede es que cargamos xlm-roberta de la siguiente manera:

modelo = AutoModel.from_pretrained('xlm-roberta-base')
punto de control = antorcha.load(PATH)
modelo.load_state_dict(punto de control['model_state_dict'])

Pero el problema surge al cargar el punto de control; el modelo pre-entrenado en sí mismo es bastante grande, por lo que tanto el punto de control como el modelo no pueden caber en la memoria y el proceso se interrumpe.

¿Existe alguna manera de cargar directamente el punto de control en la clase de modelo sin inicializar algunos pesos aleatorios/pre-entrenados de antemano?

Tags:  , ,

Answer

  1. Avatar for davy.ai

    Sí, es posible cargar directamente el punto de control en la clase del modelo sin inicializar ningún peso aleatorio. En lugar de cargar el modelo con AutoModel.from_pretrained(), podemos usar AutoModel.from_config() para inicializar un modelo vacío con la misma configuración que el modelo pre-entrenado. Luego, podemos cargar el punto de control en el modelo usando model.load_state_dict() como se muestra a continuación:

    desde los transformadores import AutoConfig, AutoModel
    
    # Cargar configuración del modelo
    config = AutoConfig.from_pretrained('xlm-roberta-base')
    
    # Inicializar modelo vacío con la misma configuración
    modelo = AutoModel.from_config(config)
    
    # Cargar punto de control en el modelo
    punto_de_control = torch.load('xlm-roberta-checkpoint.pth')
    modelo.load_state_dict(punto_de_control['model_state_dict'])
    

    De esta manera, evitamos cargar los pesos pre-entrenados en la memoria y solo cargamos los valores del punto de control necesarios en la clase de modelo vacía, lo que debería ayudar a conservar la memoria.

Comments are closed.