es.davy.ai

Preguntas y respuestas de programación confiables

¿Tienes una pregunta?

Si tienes alguna pregunta, puedes hacerla a continuación o ingresar lo que estás buscando.

¿Cómo retropropaga un gradiente a través de muestras aleatorias?

Estoy aprendiendo sobre políticas de gradientes y estoy teniendo dificultades para entender cómo el gradiente pasa a través de una operación aleatoria. Desde aquí: “No es posible retropropagar directamente a través de muestras aleatorias. Sin embargo, existen dos métodos principales para crear funciones sustitutas que pueden retropropagarse a través”.

Tienen un ejemplo de la “función de puntuación”:

probs = policy_network(state)

Tenga en cuenta que esto es equivalente a lo que solía llamarse multinomial

m = Categorical(probs)
action = m.sample()
next_state, reward = env.step(action)
loss = -m.log_prob(action) * reward
loss.backward()

Lo cual intenté crear un ejemplo de:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Normal
import matplotlib.pyplot as plt
from tqdm import tqdm

softplus = torch.nn.Softplus()

class Model_RL(nn.Module):
def init(self):
super(Model_RL, self).init()
self.fc1 = nn.Linear(1, 20)
self.fc2 = nn.Linear(20, 30)
self.fc3 = nn.Linear(30, 2)

def forward(self, x):
    x1 = self.fc1(x)
    x = torch.relu(x1)
    x2 = self.fc2(x)
    x = torch.relu(x2)
    x3 = softplus(self.fc3(x))
    return x3, x2, x1

básico

net_RL = Model_RL()

features = torch.tensor([1.0])
x = torch.tensor([1.0])
y = torch.tensor(3.0)

baseline = 0
baseline_lr = 0.1

epochs = 3

opt_RL = optim.Adam(net_RL.parameters(), lr=1e-3)
losses = []
xs = []
for _ in tqdm(range(epochs)):
out_RL = net_RL(x)
mu, std = out_RL[0]
dist = Normal(mu, std)
print(dist)
a = dist.sample()
log_p = dist.log_prob(a)

out = features * a
reward = -torch.square((y - out))
baseline = (1-baseline_lr)*baseline + baseline_lr*reward

loss = -(reward-baseline)*log_p

opt_RL.zero_grad()
loss.backward()
opt_RL.step()
losses.append(loss.item())

Esto parece funcionar mágicamente bien, lo cual de nuevo, no entiendo cómo el gradiente pasa a través de la operación aleatoria (pero de alguna manera lo hace).

Ahora, dado que el gradiente no puede fluir a través de la operación aleatoria, intenté reemplazar mu, std = out_RL[0] con mu, std = out_RL[0].detach() y eso provocó el error:

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn.

Si el gradiente no pasa a través de la operación aleatoria, no entiendo por qué desvincular un tensor antes de la operación importa.

Tags:  , , ,

Answer

  1. Avatar for davy.ai

    La razón por la cual el gradiente no puede pasar directamente a través de una operación aleatoria es porque la salida de la operación no es determinista y se basa en un muestreo probabilístico. Esto significa que la función de pérdida resultante también es una variable aleatoria y no puede ser optimizada directamente utilizando el descenso del gradiente.

    Para superar esto, se utilizan funciones sustitutas para aproximar la verdadera función de pérdida que se puede propagar hacia atrás. En el ejemplo proporcionado, la “función de puntuación” se utiliza como una función sustituta para estimar el gradiente esperado de la función de pérdida con respecto a los parámetros del modelo.

    En el código proporcionado, la salida out_RL[0] se utiliza para calcular la media y la desviación estándar de una distribución normal, que luego se utiliza para muestrear una acción a utilizando la distribución Normal. Dado que la distribución Normal es una función determinista de la media y la desviación estándar, los gradientes pueden pasar a través de esta operación y, por lo tanto, se pueden utilizar para optimizar los parámetros del modelo.

    Sin embargo, al desconectar un tensor antes de la operación, se independiza de su historial anterior y, por lo tanto, se elimina del gráfico computacional. Por lo tanto, los gradientes no pueden fluir a través de esta operación, ya que no hay registro de su historial de cálculo.

    Por lo tanto, es esencial definir funciones sustitutas que se puedan propagar hacia atrás y asegurarse de que todas las operaciones que deben tener gradientes se mantengan dentro del gráfico computacional.

Comments are closed.