es.davy.ai

Preguntas y respuestas de programación confiables

¿Tienes una pregunta?

Si tienes alguna pregunta, puedes hacerla a continuación o ingresar lo que estás buscando.

Guardar y cargar vocabulario y peso para utilizar torchtext en la inferencia.

He seguido este ejemplo para construir un clasificador de texto usando torchtext.

Después de terminar el entrenamiento, he guardado el vocabulario y los pesos, pero cuando los he cargado de nuevo para hacer una inferencia con los datos de validación, obtuve cero de precisión.

Encontré una pregunta similar en https://stackoverflow.com/questions/62744998/using-torchtext-for-inference pero no resuelve mi problema.

def save_vocab(vocab):
    import pickle
    output = open('vocab.pkl', 'wb')
    pickle.dump(vocab, output)
    output.close()

save_vocab(TEXT.vocab)
torch.save(model.state_dict(), 'saved_weights.pt')

def read_vocab(path):
    # lee el archivo pkl del vocabulario 
    import pickle
    pkl_file = open(path, 'rb')
    vocab = pickle.load(pkl_file)
    pkl_file.close()
    return vocab

vocab = read_vocab('vocab.pkl')
TEXT = data.Field(batch_first=True, include_lengths=True)
LABEL = data.LabelField(dtype=torch.float, batch_first=True)

TEXT.vocab = vocab
LABEL.vocab = vocab

pretrained_embeddings = TEXT.vocab.vectors
model.embedding.weight.data.copy_(pretrained_embeddings)
model.embedding.weight.data[PAD_IDX] = torch.zeros(EMBEDDING_DIM)

# cargar los pesos
path = 'saved_weights.pt'
model.load_state_dict(torch.load(path))
model.eval()

model = model.cuda()
acc = evaluate(model, valid_iterator)

Gracias.

Tags:  , , ,

Answer

  1. Avatar for davy.ai

    Parece que el problema podría estar relacionado con la carga del vocabulario y los pesos. Mientras que la función read_vocab se está utilizando para leer el archivo pickle del vocabulario, no está claro si TEXT.vocab y LABEL.vocab se están configurando correctamente.

    Una posible solución podría ser recrear los campos TEXT y LABEL y establecer el vocabulario basado en el archivo pickle cargado. Aquí tienes un ejemplo:

    def read_vocab(path):
        # leer archivo pkl del vocabulario
        import pickle
        pkl_file = open(path, 'rb')
        vocab = pickle.load(pkl_file)
        pkl_file.close()
        return vocab
    
    vocab = read_vocab('vocab.pkl')
    
    TEXT = data.Field(batch_first=True, include_lengths=True)
    LABEL = data.LabelField(dtype=torch.float, batch_first=True)
    
    # configurar los vocabularios
    TEXT.vocab = vocab
    LABEL.vocab = vocab
    
    # crear el iterador con los conjuntos de datos
    _, valid_data = datasets.IMDB.splits(TEXT, LABEL)
    
    # establecer el tamaño del lote (BATCH_SIZE) y el dispositivo (device)
    BATCH_SIZE = 64
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    valid_iterator = data.BucketIterator(valid_data, batch_size=BATCH_SIZE, device=device, shuffle=False)
    
    # crear el modelo y cargar los pesos
    model = LSTMClassifier(len(TEXT.vocab), EMBEDDING_DIM, HIDDEN_DIM, OUTPUT_DIM, N_LAYERS, BIDIRECTIONAL, DROPOUT)
    
    path = 'saved_weights.pt'
    model.load_state_dict(torch.load(path))
    model.eval()
    
    # evaluar el modelo
    acc = evaluate(model, valid_iterator)
    

    En este ejemplo, los campos TEXT y LABEL se recrean y los vocabularios se configuran después de cargar el archivo pickle. Se crea valid_data con estos campos y se crea el iterador utilizando estos datos. El resto del código debería ser igual que en el ejemplo original.

    ¡Espero que esto ayude a resolver el problema!

Comments are closed.