Cómo obtener tokens generados en el paso de entrenamiento de T5 para utilizar métricas definidas por el usuario?
Estoy afinando T5 para la generación de respuestas a preguntas y quiero agregar medidas adicionales (por ejemplo, BLEU, ROUGE) para las respuestas generadas, además de la función de pérdida.
Para eso, creo que sería necesario obtener los tokens generados (respuestas) en cada training_step. Sin embargo, después de leer el código fuente, aún no tengo idea de cómo hacerlo.
A continuación, dejo un fragmento de mi código. Puedo extraer la pérdida de salida (output.loss) y los logits de salida (output.logits), pero no encontré una forma de obtener los tokens generados para usar métricas de evaluación adicionales.
Gracias de antemano.
class MyQAModel(pl.LightningModule):
def __init__(self):
super().__init__()
self.model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME, return_dict=True)
def forward(self, input_ids, attention_mask, labels=None):
output = self.model(
input_ids,
attention_mask=attention_mask,
labels=labels)
return output.loss, output.logits
def training_step(self, batch, batch_idx):
input_ids = batch['input_ids']
attention_mask = batch['attention_mask']
labels = batch['labels']
loss, outputs = self(input_ids, attention_mask, labels)
self.log("train_loss", loss, prog_bar=True, logger=True)
return {"loss": loss, "predictions":outputs, "labels": labels}
davy.ai
Para obtener los tokens generados durante el ajuste fino (fine-tuning) de T5 para la generación de respuestas a preguntas, puedes usar el método
generate
proporcionado por la biblioteca transformers de Hugging Face. En particular, puedes utilizar el métodomodel.generate
para generar respuestas dada la pregunta de entrada y el contexto.Aquí tienes un ejemplo de cómo modificar tu código para obtener los tokens generados:
En el código modificado, el método
generate_answer
tomainput_ids
yattention_mask
como entradas, y devuelve los tokens de respuesta generados utilizando el métodomodel.generate
. A continuación, el métodotraining_step
llama agenerate_answer
para cada elemento del lote (batch) para obtener los tokens de respuesta generados.Puedes utilizar los tokens de respuesta generados junto con las etiquetas de verdad fundamental (ground truth labels) para calcular métricas de evaluación adicionales, como la puntuación BLEU o ROUGE, como se describe en tu pregunta. En el código modificado anteriormente, hemos incluido un ejemplo de cómo calcular la puntuación BLEU utilizando la función
sentence_bleu
proporcionada por la bibliotecanltk
.