Cómo agrupar en lotes una lista anidada de gráficas en PyTorch Geometric.
Actualmente estoy entrenando un modelo que es una combinación de redes neuronales gráficas y LSTM. Sin embargo, esto significa que para cada una de mis muestras de entrenamiento, necesito pasar una lista de grafos. La clase de lote actual en torch_geometric admite el agrupamiento con torch_geometric.data.Batch.from_data_list()
, pero esto solo permite un gráfico por cada punto de datos. ¿Cómo más puedo agrupar los grafos?
davy.ai
Para agrupar varios grafos para cada muestra de entrenamiento, puedes usar la clase
torch_geometric.data.DataLoader
.Primero, debes crear una clase de conjunto de datos personalizada que herede de
torch_geometric.data.Dataset
. En tu clase de conjunto de datos personalizada, necesitarás sobrescribir el método__getitem__
para devolver la lista de grafos y el objetivo correspondiente para un índice dado.A continuación, puedes crear un objeto
DataLoader
pasando tu objeto de conjunto de datos personalizado y otros argumentos relevantes como el tamaño del lote, la aleatorización y la función de agrupación (collate_fn
). En la funcióncollate_fn
, puedes usar el métodotorch_geometric.data.Batch.from_data_list()
para crear un lote de grafos para cada muestra de entrenamiento.Aquí tienes un ejemplo de implementación:
En el ejemplo anterior, hemos definido una clase de conjunto de datos personalizada que recibe una lista de grafos y una lista de objetivos. El método
__getitem__
sobrescrito devuelve los datos del grafo y el objetivo correspondiente para un índice dado.También definimos una función
collate_fn
que recibe la lista de grafos y aplicaBatch.from_data_list()
para agrupar los grafos. Además, convertimos la lista de objetivos en untorch.tensor
.Finalmente, creamos un objeto
DataLoader
pasando nuestro objeto de conjunto de datos personalizado, el tamaño del lote, la aleatorización y la función de agrupación. Luego podemos usar este objetoDataLoader
para entrenar nuestro modelo con los datos de los grafos agrupados.