es.davy.ai

Preguntas y respuestas de programación confiables

¿Tienes una pregunta?

Si tienes alguna pregunta, puedes hacerla a continuación o ingresar lo que estás buscando.

¿Cuáles son los elementos de los datos en Pytorch Geometric?

Estoy estudiando sobre GNN, y para codificar:

Código de introducción de Pytorch Geometric desde Tutorial de Pytorch Geometric (https://pytorch-geometric.readthedocs.io/en/latest/notes/colabs.html) en (https://colab.research.google.com/github/AntonioLonga/PytorchGeometricTutorial/blob/main/Tutorial1/Tutorial1.ipynb)

import torch_geometric
from torch_geometric.datasets import Planetoid
dataset = Planetoid(root=”tutorial1″,name=”Cora”)
data = dataset[0]
print(data)

Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])

######(Omití mi red neuronal y train(), que no están relacionados con mi pregunta)######

def test():
model.eval()
logits, accs = model(), []
for _, mask in data(‘train_mask’, ‘val_mask’, ‘test_mask’):
pred = logits[mask].max(1)[1]
acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()
accs.append(acc)
return accs

Lo que me intriga es que:

for _, mask in data(‘train_mask’, ‘val_mask’, ‘test_mask’):

Porque no entiendo lo que data(‘train_mask’,’val_mask’,’test_mask’) es. El resultado es

Así que no entiendo lo que es. Leí algunas documentaciones sobre generadores, pero ¿cómo puedo ver cuáles son los elementos?

Tags:  , ,

Answer

  1. Avatar for davy.ai

    La línea for _, mask in data('train_mask', 'val_mask', 'test_mask'): utiliza el método __call__ del objeto data para generar un objeto generador que devuelve las máscaras especificadas para los conjuntos de entrenamiento, validación y prueba en los datos.

    El objeto generador es esencialmente un iterador perezoso que genera valores sobre la marcha en lugar de mantenerlos en memoria. Para ver los elementos generados por el generador, puede utilizar list() para crear una lista de los valores:

    masks = list(data('train_mask', 'val_mask', 'test_mask'))
    print(masks)
    

    Esto debería producir una lista de tres tensores booleanos para las máscaras de entrenamiento, validación y prueba, respectivamente.

Comments are closed.