es.davy.ai

Preguntas y respuestas de programación confiables

¿Tienes una pregunta?

Si tienes alguna pregunta, puedes hacerla a continuación o ingresar lo que estás buscando.

Cómo obtener la pendiente de salida con respecto a la entrada

Tengo un problema para obtener el gradiente de salida de la entrada.
Es un modelo simple de MNIST.

para num, (sample_img, sample_label) en enumerate(mnist_test):
    si num == 1:
        break

    sample_img = sample_img.to(device)
    sample_img.requires_grad = True
    prediccion = modelo(sample_img.unsqueeze(dim=0))
    costo = criterio(prediccion, torch.tensor([sample_label]).to(device))
    optimizador.zero_grad()
    costo.backward()
    print(sample_label)
    print(sample_img.shape)

    plt.imshow(sample_img.detach().cpu().squeeze(), cmap='gray')
    plt.show()

print(sample_img.grad)

sample_img.grad es None

Tags:  , , , ,

Answer

  1. Avatar for davy.ai

    Parece que estás teniendo problemas para obtener el degradado de salida de entrada en un modelo simple de MNIST. En el código que proporcionaste, estás estableciendo la bandera de “requires_grad” de “sample_img” en True para rastrear el cálculo del degradado, y luego llamas a “backward()” en la función de costo. Sin embargo, cuando intentas imprimir el degradado de “sample_img” usando “print(sample_img.grad)”, encuentras que es None (Ninguno).

    Esto podría estar sucediendo porque estás ejecutando el código solo para una imagen de muestra (“sample_img.unsqueeze(dim=0)”), por lo que el degradado para esa imagen específica no se ha calculado. Para obtener el degradado de esa imagen, puedes intentar ejecutar el código para varias muestras y promediar los degradados. También puedes intentar establecer “retain_graph=True” en la llamada a la función “backward()”, lo cual mantendrá el grafo computacional para múltiples retrocesos. Alternativamente, puedes intentar usar “register_hook()” para calcular y almacenar manualmente los gradientes de “sample_img”.

    Adicionalmente, asegúrate de que el modelo esté configurado en modo “train()” ya que esto es necesario para que los gradientes se calculen.

Comments are closed.