es.davy.ai

Preguntas y respuestas de programación confiables

¿Tienes una pregunta?

Si tienes alguna pregunta, puedes hacerla a continuación o ingresar lo que estás buscando.

Cómo agrupar en lotes una lista anidada de gráficas en PyTorch Geometric.

Actualmente estoy entrenando un modelo que es una combinación de redes neuronales gráficas y LSTM. Sin embargo, esto significa que para cada una de mis muestras de entrenamiento, necesito pasar una lista de grafos. La clase de lote actual en torch_geometric admite el agrupamiento con torch_geometric.data.Batch.from_data_list(), pero esto solo permite un gráfico por cada punto de datos. ¿Cómo más puedo agrupar los grafos?

Tags:  , ,

Answer

  1. Avatar for davy.ai

    Para agrupar varios grafos para cada muestra de entrenamiento, puedes usar la clase torch_geometric.data.DataLoader.

    Primero, debes crear una clase de conjunto de datos personalizada que herede de torch_geometric.data.Dataset. En tu clase de conjunto de datos personalizada, necesitarás sobrescribir el método __getitem__ para devolver la lista de grafos y el objetivo correspondiente para un índice dado.

    A continuación, puedes crear un objeto DataLoader pasando tu objeto de conjunto de datos personalizado y otros argumentos relevantes como el tamaño del lote, la aleatorización y la función de agrupación (collate_fn). En la función collate_fn, puedes usar el método torch_geometric.data.Batch.from_data_list() para crear un lote de grafos para cada muestra de entrenamiento.

    Aquí tienes un ejemplo de implementación:

    import torch
    from torch_geometric.data import Dataset, DataLoader, Batch
    
    class CustomGraphDataset(Dataset):
        def __init__(self, data_list, target_list):
            self.data_list = data_list
            self.target_list = target_list
    
        def __len__(self):
            return len(self.target_list)
    
        def __getitem__(self, index):
            data = self.data_list[index]
            target = self.target_list[index]
            return data, target
    
    def collate_fn(data_list):
        graph_list, target_list = zip(*data_list)
        batch = Batch.from_data_list(graph_list)
        targets = torch.tensor(target_list)
        return batch, targets
    
    # Suponiendo que has creado tu lista de datos de entrenamiento (train_data_list) y lista de objetivos de entrenamiento (train_target_list)
    dataset = CustomGraphDataset(train_data_list, train_target_list)
    dataloader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
    

    En el ejemplo anterior, hemos definido una clase de conjunto de datos personalizada que recibe una lista de grafos y una lista de objetivos. El método __getitem__ sobrescrito devuelve los datos del grafo y el objetivo correspondiente para un índice dado.

    También definimos una función collate_fn que recibe la lista de grafos y aplica Batch.from_data_list() para agrupar los grafos. Además, convertimos la lista de objetivos en un torch.tensor.

    Finalmente, creamos un objeto DataLoader pasando nuestro objeto de conjunto de datos personalizado, el tamaño del lote, la aleatorización y la función de agrupación. Luego podemos usar este objeto DataLoader para entrenar nuestro modelo con los datos de los grafos agrupados.

Comments are closed.