es.davy.ai

Preguntas y respuestas de programación confiables

¿Tienes una pregunta?

Si tienes alguna pregunta, puedes hacerla a continuación o ingresar lo que estás buscando.

PyTorch: ¿Cómo pasar el gradiente a través de index_add?

Estoy tratando de realizar una retropropagación a través de una operación Tensor.index_add, pero se produce un error cuando index_add transmite vectores. ¿Es esto intencional o es un error en PyTorch?

El siguiente código NO arroja un error:

m = torch.zeros(4)
m.requires_grad = True

i = [0, 1, 2, 3]
v = [6, 7, 8, 9]

i = torch.tensor(i).long()
v = torch.tensor(v).float()
v.requires_grad = True

print(m)
m = m.index_add(0, i, v)
print(m)
m.sum().backward()
print(v.grad)
print("DONE")

El siguiente código arroja un error:

m = torch.zeros(4, 4)
m.requires_grad = True

i = [0, 1, 2, 3]
v = [6, 7, 8, 9]

i = torch.tensor(i).long()
v = torch.tensor(v).float()
v.requires_grad = True

print(m)
m = m.index_add(0, i, v)
print(m)
m.sum().backward()
print(v.grad)
print("DONE")

Ese código, cuando se ejecuta, da la siguiente salida y error:

tensor([0., 0., 0., 0.], requires_grad=True)
tensor([6., 7., 8., 9.], grad_fn=<indexaddbackward>)
tensor([1., 1., 1., 1.])
DONE
tensor([[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]], requires_grad=True)
tensor([[6., 6., 6., 6.],
[7., 7., 7., 7.],
[8., 8., 8., 8.],
[9., 9., 9., 9.]], grad_fn=<indexaddbackward>)
ERROR: RuntimeError: expand(torch.FloatTensor{[4, 4]}, size=[4]): the number of sizes provided (1) must be greater or equal to the number of dimensions in the tensor (2)

Tags:  , ,

Answer

  1. Avatar for davy.ai

    Es un comportamiento intencional que la operación Tensor.index_add se rompa al transmitir vectores. Esto se debe a que la transmisión puede causar errores sutiles en el cálculo del gradiente, y PyTorch tiene como objetivo proporcionar una transmisión explícita a través de la API de transmisión en su lugar. Por lo tanto, al usar index_add, el tensor y el índice deben tener la misma forma. El error en el código proporcionado se debe al hecho de que el tensor m tiene forma (4,4), pero i y v tienen forma (4,) y (4,) respectivamente, lo que provoca la transmisión al llamar a index_add. Para evitar este error, puedes cambiar la forma de i y v a (4,1) utilizando torch.unsqueeze, o utilizar la transmisión explícita en su lugar.

Comments are closed.